diff --git a/build.gradle b/build.gradle
index b0c88dbc8..b7bc40318 100644
--- a/build.gradle
+++ b/build.gradle
@@ -139,7 +139,7 @@ dependencies {
     implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0'
     implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0'
     implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
-    implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.12.0'
+    implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0'
     implementation "org.jacoco:org.jacoco.agent:0.8.5"
     implementation ("org.jacoco:org.jacoco.ant:0.8.5") {
@@ -148,6 +148,9 @@ dependencies {
         exclude group: 'org.ow2.asm', module: 'asm-tree'
     }
 
+    // used for output encoding of config descriptions
+    implementation group: 'org.owasp.encoder', name: 'encoder', version: '1.2.3'
+
     testImplementation group: 'pl.pragmatists', name: 'JUnitParams', version: '1.1.1'
     testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.3.1'
     testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.0.1'
@@ -403,6 +406,7 @@ testClusters.integTest {
         @Override
         File getAsFile() {
             return configurations.zipArchive.asFileTree.getSingleFile()
+            //return fileTree("src/test/resources/job-scheduler").getSingleFile()
         }
     }
 }
diff --git a/src/main/java/org/opensearch/ad/ADEntityProfileRunner.java b/src/main/java/org/opensearch/ad/ADEntityProfileRunner.java
new file mode 100644
index 000000000..897c853f1
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/ADEntityProfileRunner.java
@@ -0,0 +1,46 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad;
+
+import org.opensearch.ad.constant.ADCommonName;
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.settings.ADNumericSetting;
+import org.opensearch.ad.transport.ADEntityProfileAction;
+import org.opensearch.client.Client;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.EntityProfileRunner;
+import org.opensearch.timeseries.util.SecurityClientUtil;
+
+public class ADEntityProfileRunner extends EntityProfileRunner {
+
+    public ADEntityProfileRunner(
+        Client client,
+        SecurityClientUtil clientUtil,
+        NamedXContentRegistry xContentRegistry,
+        long requiredSamples
+    ) {
+        super(
+            client,
+            clientUtil,
+            xContentRegistry,
+            requiredSamples,
+            AnomalyDetector::parse,
+            ADNumericSetting.maxCategoricalFields(),
+            AnalysisType.AD,
+            ADEntityProfileAction.INSTANCE,
+            ADCommonName.ANOMALY_RESULT_INDEX_ALIAS,
+            AnomalyResult.DETECTOR_ID_FIELD
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/ADJobProcessor.java b/src/main/java/org/opensearch/ad/ADJobProcessor.java
new file mode 100644
index 000000000..4492d3708
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/ADJobProcessor.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad;
+
+import java.time.Instant;
+import java.util.List;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.ad.indices.ADIndex;
+import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
+import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.rest.handler.ADIndexJobActionHandler;
+import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.ad.task.ADTaskCacheManager;
+import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.ad.transport.ADProfileAction;
+import org.opensearch.ad.transport.AnomalyResultAction;
+import org.opensearch.ad.transport.AnomalyResultRequest;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.jobscheduler.spi.LockModel;
+import org.opensearch.jobscheduler.spi.utils.LockService;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.JobProcessor;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.common.exception.EndRunException;
+import org.opensearch.timeseries.model.Config;
+import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.transport.ResultRequest;
+
+public class ADJobProcessor extends
+    JobProcessor {
+
+    private static final Logger log = LogManager.getLogger(ADJobProcessor.class);
+
+    private static ADJobProcessor INSTANCE;
+
+    public static ADJobProcessor getInstance() {
+        if (INSTANCE != null) {
+            return INSTANCE;
+        }
+        synchronized (ADJobProcessor.class) {
+            if (INSTANCE != null) {
+                return INSTANCE;
+            }
+            INSTANCE = new ADJobProcessor();
+            return INSTANCE;
+        }
+    }
+
+    private ADJobProcessor() {
+        // Singleton class, use getInstance method instead of constructor
+        super(AnalysisType.AD, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, AnomalyResultAction.INSTANCE);
+    }
+
+    public void registerSettings(Settings settings) {
+        super.registerSettings(settings, AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION);
+    }
+
+    @Override
+    protected ResultRequest createResultRequest(String configId, long start, long end) {
+        return new AnomalyResultRequest(configId, start, end);
+    }
+
+    @Override
+    protected void validateResultIndexAndRunJob(
+        Job jobParameter,
+        LockService lockService,
+        LockModel lock,
+        Instant executionStartTime,
+        Instant executionEndTime,
+        String configId,
+        String user,
+        List<String> roles,
+        ExecuteADResultResponseRecorder recorder,
+        Config detector
+    ) {
+        String resultIndex = jobParameter.getCustomResultIndex();
+        if (resultIndex == null) {
+            runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector);
+            return;
+        }
+        ActionListener<Boolean> listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> {
+            Exception exception = new EndRunException(configId, e.getMessage(), false);
+            handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, exception, recorder, detector);
+        });
+        indexManagement.validateCustomIndexForBackendJob(resultIndex, configId, user, roles, () -> {
+            listener.onResponse(true);
+            runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector);
+        }, listener);
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/ADTaskProfileRunner.java b/src/main/java/org/opensearch/ad/ADTaskProfileRunner.java
new file mode 100644
index 000000000..6bad4935c
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/ADTaskProfileRunner.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskProfile;
+import org.opensearch.ad.transport.ADTaskProfileAction;
+import org.opensearch.ad.transport.ADTaskProfileNodeResponse;
+import org.opensearch.ad.transport.ADTaskProfileRequest;
+import org.opensearch.client.Client;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.timeseries.TaskProfileRunner;
+import org.opensearch.timeseries.cluster.HashRing;
+import org.opensearch.timeseries.model.EntityTaskProfile;
+
+public class ADTaskProfileRunner implements TaskProfileRunner {
+    public final Logger logger = LogManager.getLogger(ADTaskProfileRunner.class);
+
+    private final HashRing hashRing;
+    private final Client client;
+
+    public ADTaskProfileRunner(HashRing hashRing, Client client) {
+        this.hashRing = hashRing;
+        this.client = client;
+    }
+
+    @Override
+    public void getTaskProfile(ADTask configLevelTask, ActionListener<ADTaskProfile> listener) {
+        String detectorId = configLevelTask.getConfigId();
+
+        hashRing.getAllEligibleDataNodesWithKnownVersion(dataNodes -> {
+            ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes);
+            client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> {
+                if (response.hasFailures()) {
+                    listener.onFailure(response.failures().get(0));
+                    return;
+                }
+
+                List<EntityTaskProfile> adEntityTaskProfiles = new ArrayList<>();
+                ADTaskProfile detectorTaskProfile = new ADTaskProfile(configLevelTask);
+                for (ADTaskProfileNodeResponse node : response.getNodes()) {
+                    ADTaskProfile taskProfile = node.getAdTaskProfile();
+                    if (taskProfile != null) {
+                        if (taskProfile.getNodeId() != null) {
+                            // HC detector: task profile from coordinating node
+                            // Single entity detector: task profile from worker node
+                            detectorTaskProfile.setTaskId(taskProfile.getTaskId());
+                            detectorTaskProfile.setShingleSize(taskProfile.getShingleSize());
+                            detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates());
+                            detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained());
+                            detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize());
+                            detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes());
+                            detectorTaskProfile.setNodeId(taskProfile.getNodeId());
+                            detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount());
+                            detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots());
+                            detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount());
+                            detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount());
+                            detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities());
+                            detectorTaskProfile.setTaskType(taskProfile.getTaskType());
+                        }
+                        if (taskProfile.getEntityTaskProfiles() != null) {
+                            adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles());
+                        }
+                    }
+                }
+                if (!adEntityTaskProfiles.isEmpty()) {
+                    detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles);
+                }
+                listener.onResponse(detectorTaskProfile);
+            }, e -> {
+                logger.error("Failed to get task profile for task " + configLevelTask.getTaskId(), e);
+                listener.onFailure(e);
+            }));
+        }, listener);
+
+    }
+
+}
diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java
deleted file mode 100644
index 98135e1ee..000000000
--- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java
+++ /dev/null
@@ -1,653 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */ - -package org.opensearch.ad; - -import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; -import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; - -import java.io.IOException; -import java.time.Instant; -import java.util.List; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyResultAction; -import org.opensearch.ad.transport.AnomalyResultRequest; -import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.AnomalyResultTransportAction; -import org.opensearch.client.Client; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.commons.InjectSecurity; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.jobscheduler.spi.JobExecutionContext; -import org.opensearch.jobscheduler.spi.LockModel; -import org.opensearch.jobscheduler.spi.ScheduledJobParameter; -import org.opensearch.jobscheduler.spi.ScheduledJobRunner; -import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; -import org.opensearch.jobscheduler.spi.utils.LockService; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.InternalFailure; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.model.TaskState; -import org.opensearch.timeseries.util.SecurityUtil; - -import com.google.common.base.Throwables; - -/** - * JobScheduler will call AD job runner to get anomaly result periodically - */ -public class AnomalyDetectorJobRunner implements ScheduledJobRunner { - private static final Logger log = LogManager.getLogger(AnomalyDetectorJobRunner.class); - private static AnomalyDetectorJobRunner INSTANCE; - private Settings settings; - private int maxRetryForEndRunException; - private Client client; - private ThreadPool threadPool; - private ConcurrentHashMap detectorEndRunExceptionCount; - private ADIndexManagement anomalyDetectionIndices; - private ADTaskManager adTaskManager; - private NodeStateManager 
nodeStateManager; - private ExecuteADResultResponseRecorder recorder; - - public static AnomalyDetectorJobRunner getJobRunnerInstance() { - if (INSTANCE != null) { - return INSTANCE; - } - synchronized (AnomalyDetectorJobRunner.class) { - if (INSTANCE != null) { - return INSTANCE; - } - INSTANCE = new AnomalyDetectorJobRunner(); - return INSTANCE; - } - } - - private AnomalyDetectorJobRunner() { - // Singleton class, use getJobRunnerInstance method instead of constructor - this.detectorEndRunExceptionCount = new ConcurrentHashMap<>(); - } - - public void setClient(Client client) { - this.client = client; - } - - public void setThreadPool(ThreadPool threadPool) { - this.threadPool = threadPool; - } - - public void setSettings(Settings settings) { - this.settings = settings; - this.maxRetryForEndRunException = AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings); - } - - public void setAdTaskManager(ADTaskManager adTaskManager) { - this.adTaskManager = adTaskManager; - } - - public void setAnomalyDetectionIndices(ADIndexManagement anomalyDetectionIndices) { - this.anomalyDetectionIndices = anomalyDetectionIndices; - } - - public void setNodeStateManager(NodeStateManager nodeStateManager) { - this.nodeStateManager = nodeStateManager; - } - - public void setExecuteADResultResponseRecorder(ExecuteADResultResponseRecorder recorder) { - this.recorder = recorder; - } - - @Override - public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { - String detectorId = scheduledJobParameter.getName(); - log.info("Start to run AD job {}", detectorId); - adTaskManager.refreshRealtimeJobRunTime(detectorId); - if (!(scheduledJobParameter instanceof Job)) { - throw new IllegalArgumentException( - "Job parameter is not instance of Job, type: " + scheduledJobParameter.getClass().getCanonicalName() - ); - } - Job jobParameter = (Job) scheduledJobParameter; - Instant executionStartTime = Instant.now(); - IntervalSchedule schedule = (IntervalSchedule) jobParameter.getSchedule(); - Instant detectionStartTime = executionStartTime.minus(schedule.getInterval(), schedule.getUnit()); - - final LockService lockService = context.getLockService(); - - Runnable runnable = () -> { - try { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); - return; - } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - - if (jobParameter.getLockDurationSeconds() != null) { - lockService - .acquireLock( - jobParameter, - context, - ActionListener - .wrap( - lock -> runAdJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - recorder, - detector - ), - exception -> { - indexAnomalyResultException( - jobParameter, - lockService, - null, - detectionStartTime, - executionStartTime, - exception, - false, - recorder, - detector - ); - throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); - } - ) - ); - } else { - log.warn("Can't get lock for AD job: " + detectorId); - } - - }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); - } catch (Exception e) { - // os log won't show anything if there is an exception happens (maybe due to running on a ExecutorService) - // we at least log the error. 
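The two comments above point at a real JVM pitfall that is easy to miss: a Runnable handed to ExecutorService.submit() does not propagate its exception to any logger or uncaught-exception handler; the throwable is captured inside the returned Future and only surfaces on Future.get(). A minimal, self-contained sketch of that behavior (illustrative aside, not part of this patch):

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class SwallowedExceptionDemo {
    public static void main(String[] args) throws InterruptedException {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        // The cast selects the Runnable overload of submit; the thrown exception is
        // silently captured in the Future instead of being printed or rethrown.
        Future<?> future = pool.submit((Runnable) () -> { throw new IllegalStateException("lost unless logged"); });
        try {
            future.get(); // the only place the failure surfaces, wrapped in ExecutionException
        } catch (ExecutionException e) {
            System.out.println("surfaced: " + e.getCause());
        }
        pool.shutdown();
    }
}

Hence the catch block below: the runner logs inside the Runnable itself before rethrowing, so the failure is visible even though nothing ever calls get() on the submitted task.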
- log.error("Can't start AD job: " + detectorId, e); - throw e; - } - }; - - ExecutorService executor = threadPool.executor(AD_THREAD_POOL_NAME); - executor.submit(runnable); - } - - /** - * Get anomaly result, index result or handle exception if failed. - * - * @param jobParameter scheduled job parameter - * @param lockService lock service - * @param lock lock to run job - * @param detectionStartTime detection start time - * @param executionStartTime detection end time - * @param recorder utility to record job execution result - * @param detector associated detector accessor - */ - protected void runAdJob( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - if (lock == null) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - "Can't run AD job due to null lock", - false, - recorder, - detector - ); - return; - } - anomalyDetectionIndices.update(); - - User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); - - String user = userInfo.getName(); - List roles = userInfo.getRoles(); - - String resultIndex = jobParameter.getCustomResultIndex(); - if (resultIndex == null) { - runAnomalyDetectionJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - detectorId, - user, - roles, - recorder, - detector - ); - return; - } - ActionListener listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> { - Exception exception = new EndRunException(detectorId, e.getMessage(), true); - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); - }); - anomalyDetectionIndices.validateCustomIndexForBackendJob(resultIndex, detectorId, user, roles, () -> { - listener.onResponse(true); - runAnomalyDetectionJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - detectorId, - user, - roles, - recorder, - detector - ); - }, listener); - } - - private void runAnomalyDetectionJob( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - String detectorId, - String user, - List roles, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - // using one thread in the write threadpool - try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { - // Injecting user role to verify if the user has permissions for our API. 
- injectSecurity.inject(user, roles); - - AnomalyResultRequest request = new AnomalyResultRequest( - detectorId, - detectionStartTime.toEpochMilli(), - executionStartTime.toEpochMilli() - ); - client.execute(AnomalyResultAction.INSTANCE, request, ActionListener.wrap(response -> { - indexAnomalyResult(jobParameter, lockService, lock, detectionStartTime, executionStartTime, response, recorder, detector); - }, exception -> { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); - })); - } catch (Exception e) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - e, - true, - recorder, - detector - ); - log.error("Failed to execute AD job " + detectorId, e); - } - } - - /** - * Handle exception from anomaly result action. - * - * 1. If exception is {@link EndRunException} - * a). if isEndNow == true, stop AD job and store exception in anomaly result - * b). if isEndNow == false, record count of {@link EndRunException} for this - * detector. If count of {@link EndRunException} exceeds upper limit, will - * stop AD job and store exception in anomaly result; otherwise, just - * store exception in anomaly result, not stop AD job for the detector. - * - * 2. If exception is not {@link EndRunException}, decrease count of - * {@link EndRunException} for the detector and index eception in Anomaly - * result. If exception is {@link InternalFailure}, will not log exception - * stack trace as already logged in {@link AnomalyResultTransportAction}. - * - * TODO: Handle finer granularity exception such as some exception may be - * transient and retry in current job may succeed. Currently, we don't - * know which exception is transient and retryable in - * {@link AnomalyResultTransportAction}. So we don't add backoff retry - * now to avoid bring extra load to cluster, expecially the code start - * process is relatively heavy by sending out 24 queries, initializing - * models, and saving checkpoints. - * Sometimes missing anomaly and notification is not acceptable. For example, - * current detection interval is 1hour, and there should be anomaly in - * current interval, some transient exception may fail current AD job, - * so no anomaly found and user never know it. Then we start next AD job, - * maybe there is no anomaly in next 1hour, user will never know something - * wrong happened. In one word, this is some tradeoff between protecting - * our performance, user experience and what we can do currently. - * - * @param jobParameter scheduled job parameter - * @param lockService lock service - * @param lock lock to run job - * @param detectionStartTime detection start time - * @param executionStartTime detection end time - * @param exception exception - * @param recorder utility to record job execution result - * @param detector associated detector accessor - */ - protected void handleAdException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - Exception exception, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - if (exception instanceof EndRunException) { - log.error("EndRunException happened when executing anomaly result action for " + detectorId, exception); - - if (((EndRunException) exception).isEndNow()) { - // Stop AD job if EndRunException shows we should end job now. 
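The long javadoc above is easier to follow when the three-way policy is distilled. A sketch of the control flow (stopJobAndRecord and recordOnly are hypothetical stand-ins for stopAdJobForEndRunException and indexAnomalyResultException; counter stands for the detectorEndRunExceptionCount map):

// on exception from the anomaly result action:
if (exception instanceof EndRunException) {
    if (((EndRunException) exception).isEndNow()) {
        stopJobAndRecord(exception);                                  // case 1a: stop the job immediately
    } else if (counter.merge(detectorId, 1, Integer::sum) > maxRetryForEndRunException) {
        stopJobAndRecord(exception);                                  // case 1b: too many consecutive failures
    } else {
        recordOnly(exception);                                        // case 1b: record, keep the job running
    }
} else {
    counter.remove(detectorId);                                       // case 2: reset the consecutive count
    recordOnly(exception);
}

Here merge(detectorId, 1, Integer::sum) on a ConcurrentHashMap of detector id to failure count is equivalent to the compute() call in the method body below.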
- log.info("JobRunner will stop AD job due to EndRunException for {}", detectorId); - stopAdJobForEndRunException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - (EndRunException) exception, - recorder, - detector - ); - } else { - detectorEndRunExceptionCount.compute(detectorId, (k, v) -> { - if (v == null) { - return 1; - } else { - return v + 1; - } - }); - log.info("EndRunException happened for {}", detectorId); - // if AD job failed consecutively due to EndRunException and failed times exceeds upper limit, will stop AD job - if (detectorEndRunExceptionCount.get(detectorId) > maxRetryForEndRunException) { - log - .info( - "JobRunner will stop AD job due to EndRunException retry exceeds upper limit {} for {}", - maxRetryForEndRunException, - detectorId - ); - stopAdJobForEndRunException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - (EndRunException) exception, - recorder, - detector - ); - return; - } - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - exception.getMessage(), - true, - recorder, - detector - ); - } - } else { - detectorEndRunExceptionCount.remove(detectorId); - if (exception instanceof InternalFailure) { - log.error("InternalFailure happened when executing anomaly result action for " + detectorId, exception); - } else { - log.error("Failed to execute anomaly result action for " + detectorId, exception); - } - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - exception, - true, - recorder, - detector - ); - } - } - - private void stopAdJobForEndRunException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - EndRunException exception, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - detectorEndRunExceptionCount.remove(detectorId); - String errorPrefix = exception.isEndNow() - ? 
"Stopped detector: " - : "Stopped detector as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; - String error = errorPrefix + exception.getMessage(); - stopAdJob( - detectorId, - () -> indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - error, - true, - TaskState.STOPPED.name(), - recorder, - detector - ) - ); - } - - private void stopAdJob(String detectorId, ExecutorFunction function) { - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - ActionListener listener = ActionListener.wrap(response -> { - if (response.isExists()) { - try ( - XContentParser parser = XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - if (job.isEnabled()) { - Job newJob = new Job( - job.getName(), - job.getSchedule(), - job.getWindowDelay(), - false, - job.getEnabledTime(), - Instant.now(), - Instant.now(), - job.getLockDurationSeconds(), - job.getUser(), - job.getCustomResultIndex() - ); - IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(newJob.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)) - .id(detectorId); - - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - if (indexResponse != null && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { - log.info("AD Job was disabled by JobRunner for " + detectorId); - // function.execute(); - } else { - log.warn("Failed to disable AD job for " + detectorId); - } - }, exception -> { log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception); })); - } else { - log.info("AD Job was disabled for " + detectorId); - } - } catch (IOException e) { - log.error("JobRunner failed to stop detector job " + detectorId, e); - } - } else { - log.info("AD Job was not found for " + detectorId); - } - }, exception -> log.error("JobRunner failed to get detector job " + detectorId, exception)); - - client.get(getRequest, ActionListener.runAfter(listener, () -> function.execute())); - } - - private void indexAnomalyResult( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - AnomalyResultResponse response, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - String detectorId = jobParameter.getName(); - detectorEndRunExceptionCount.remove(detectorId); - try { - recorder.indexAnomalyResult(detectionStartTime, executionStartTime, response, detector); - } catch (EndRunException e) { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, recorder, detector); - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); - } finally { - releaseLock(jobParameter, lockService, lock); - } - - } - - private void indexAnomalyResultException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - Exception exception, - boolean releaseLock, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - try { - String errorMessage = exception instanceof TimeSeriesException - ? 
exception.getMessage() - : Throwables.getStackTraceAsString(exception); - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - errorMessage, - releaseLock, - recorder, - detector - ); - } catch (Exception e) { - log.error("Failed to index anomaly result for " + jobParameter.getName(), e); - } - } - - private void indexAnomalyResultException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - String errorMessage, - boolean releaseLock, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - errorMessage, - releaseLock, - null, - recorder, - detector - ); - } - - private void indexAnomalyResultException( - Job jobParameter, - LockService lockService, - LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - String errorMessage, - boolean releaseLock, - String taskState, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector - ) { - try { - recorder.indexAnomalyResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); - } finally { - if (releaseLock) { - releaseLock(jobParameter, lockService, lock); - } - } - } - - private void releaseLock(Job jobParameter, LockService lockService, LockModel lock) { - lockService - .release( - lock, - ActionListener.wrap(released -> { log.info("Released lock for AD job {}", jobParameter.getName()); }, exception -> { - log.error("Failed to release lock for AD job: " + jobParameter.getName(), exception); - }) - ); - } -} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index 0fb2fe7fb..5119220b9 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -11,80 +11,56 @@ package org.opensearch.ad; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_PARSE_DETECTOR_MSG; -import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import java.util.List; -import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.core.util.Throwables; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; -import 
org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileRequest; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingRequest; import org.opensearch.ad.transport.RCFPollingResponse; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.SearchHits; -import org.opensearch.search.aggregations.Aggregation; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; -import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; -import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; -import org.opensearch.search.aggregations.metrics.InternalCardinality; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ProfileRunner; import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; -public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { +public class AnomalyDetectorProfileRunner extends + ProfileRunner { + private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); - private Client client; - private SecurityClientUtil clientUtil; - private NamedXContentRegistry xContentRegistry; - private DiscoveryNodeFilterer nodeFilter; - private final TransportService transportService; - private final ADTaskManager adTaskManager; - private final int maxTotalEntitiesToTrack; public AnomalyDetectorProfileRunner( Client client, @@ -93,300 +69,133 @@ public AnomalyDetectorProfileRunner( DiscoveryNodeFilterer nodeFilter, long requiredSamples, TransportService transportService, - ADTaskManager adTaskManager + ADTaskManager adTaskManager, + ADTaskProfileRunner taskProfileRunner ) { - super(requiredSamples); - this.client = client; - this.clientUtil = clientUtil; - this.xContentRegistry = xContentRegistry; - 
this.nodeFilter = nodeFilter; - if (requiredSamples <= 0) { - throw new IllegalArgumentException("required samples should be a positive number, but was " + requiredSamples); - } - this.transportService = transportService; - this.adTaskManager = adTaskManager; - this.maxTotalEntitiesToTrack = TimeSeriesSettings.MAX_TOTAL_ENTITIES_TO_TRACK; + super( + client, + clientUtil, + xContentRegistry, + nodeFilter, + requiredSamples, + transportService, + adTaskManager, + AnalysisType.AD, + ADTaskType.REALTIME_TASK_TYPES, + ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES, + ADNumericSetting.maxCategoricalFields(), + ProfileName.AD_TASK, + ADProfileAction.INSTANCE, + AnomalyDetector::parse, + taskProfileRunner + ); } - public void profile(String detectorId, ActionListener listener, Set profilesToCollect) { - if (profilesToCollect.isEmpty()) { - listener.onFailure(new IllegalArgumentException(ADCommonMessages.EMPTY_PROFILES_COLLECT)); - return; - } - calculateTotalResponsesToWait(detectorId, profilesToCollect, listener); - } - - private void calculateTotalResponsesToWait( - String detectorId, - Set profilesToCollect, - ActionListener listener - ) { - GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); - client.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> { - if (getDetectorResponse != null && getDetectorResponse.isExists()) { - try ( - XContentParser xContentParser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getDetectorResponse.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); - AnomalyDetector detector = AnomalyDetector.parse(xContentParser, detectorId); - prepareProfile(detector, listener, profilesToCollect); - } catch (Exception e) { - logger.error(FAIL_TO_PARSE_DETECTOR_MSG + detectorId, e); - listener.onFailure(new OpenSearchStatusException(FAIL_TO_PARSE_DETECTOR_MSG + detectorId, BAD_REQUEST)); - } - } else { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, BAD_REQUEST)); - } - }, exception -> { - logger.error(FAIL_TO_FIND_CONFIG_MSG + detectorId, exception); - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, INTERNAL_SERVER_ERROR)); - })); + @Override + protected DetectorProfile.Builder createProfileBuilder() { + return new DetectorProfile.Builder(); } - private void prepareProfile( - AnomalyDetector detector, - ActionListener listener, - Set profilesToCollect - ) { - String detectorId = detector.getId(); - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, detectorId); - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (getResponse != null && getResponse.isExists()) { - try ( - XContentParser parser = XContentType.JSON - .xContent() - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - long enabledTimeMs = job.getEnabledTime().toEpochMilli(); - - boolean isMultiEntityDetector = detector.isHighCardinality(); - - int totalResponsesToWait = 0; - if (profilesToCollect.contains(DetectorProfileName.ERROR)) { - totalResponsesToWait++; - } - - // total number of listeners we need to define. 
Needed by MultiResponsesDelegateActionListener to decide - // when to consolidate results and return to users - if (isMultiEntityDetector) { - if (profilesToCollect.contains(DetectorProfileName.TOTAL_ENTITIES)) { - totalResponsesToWait++; - } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS) - || profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) - || profilesToCollect.contains(DetectorProfileName.STATE)) { - totalResponsesToWait++; - } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + @Override + protected void prepareProfile(Config config, ActionListener listener, Set profilesToCollect) { + boolean isHC = config.isHighCardinality(); + if (isHC) { + super.prepareProfile(config, listener, profilesToCollect); + } else { + String configId = config.getId(); + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, configId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + + int totalResponsesToWait = 0; + if (profilesToCollect.contains(ProfileName.ERROR)) { totalResponsesToWait++; } - } else { - if (profilesToCollect.contains(DetectorProfileName.STATE) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { + + // total number of listeners we need to define. 
Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + + if (profilesToCollect.contains(ProfileName.STATE) || profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { totalResponsesToWait++; } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS)) { + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS)) { totalResponsesToWait++; } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + if (profilesToCollect.contains(ProfileName.AD_TASK)) { totalResponsesToWait++; } - } - - MultiResponsesDelegateActionListener delegateListener = - new MultiResponsesDelegateActionListener( - listener, - totalResponsesToWait, - ADCommonMessages.FAIL_FETCH_ERR_MSG + detectorId, - false - ); - if (profilesToCollect.contains(DetectorProfileName.ERROR)) { - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, adTask -> { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - if (adTask.isPresent()) { - long lastUpdateTimeMs = adTask.get().getLastUpdateTime().toEpochMilli(); - // if state index hasn't been updated, we should not use the error field - // For example, before a detector is enabled, if the error message contains - // the phrase "stopped due to blah", we should not show this when the detector - // is enabled. - if (lastUpdateTimeMs > enabledTimeMs && adTask.get().getError() != null) { - profileBuilder.error(adTask.get().getError()); + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + totalResponsesToWait, + CommonMessages.FAIL_FETCH_ERR_MSG + configId, + false + ); + if (profilesToCollect.contains(ProfileName.ERROR)) { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, realTimeTaskTypes, task -> { + DetectorProfile.Builder profileBuilder = createProfileBuilder(); + if (task.isPresent()) { + long lastUpdateTimeMs = task.get().getLastUpdateTime().toEpochMilli(); + + // if state index hasn't been updated, we should not use the error field + // For example, before a detector is enabled, if the error message contains + // the phrase "stopped due to blah", we should not show this when the detector + // is enabled. + if (lastUpdateTimeMs > enabledTimeMs && task.get().getError() != null) { + profileBuilder.error(task.get().getError()); + } + delegateListener.onResponse(profileBuilder.build()); + } else { + // detector state for this detector does not exist + delegateListener.onResponse(profileBuilder.build()); } - delegateListener.onResponse(profileBuilder.build()); - } else { - // detector state for this detector does not exist - delegateListener.onResponse(profileBuilder.build()); - } - }, transportService, false, delegateListener); - } - - // total number of listeners we need to define. 
Needed by MultiResponsesDelegateActionListener to decide - // when to consolidate results and return to users - if (isMultiEntityDetector) { - if (profilesToCollect.contains(DetectorProfileName.TOTAL_ENTITIES)) { - profileEntityStats(delegateListener, detector); - } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS) - || profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) - || profilesToCollect.contains(DetectorProfileName.STATE)) { - profileModels(detector, profilesToCollect, job, true, delegateListener); + }, transportService, false, delegateListener); } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { - adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, null, delegateListener); - } - } else { - if (profilesToCollect.contains(DetectorProfileName.STATE) - || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { - profileStateRelated(detector, delegateListener, job.isEnabled(), profilesToCollect); + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + + if (profilesToCollect.contains(ProfileName.STATE) || profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { + profileStateRelated(config, delegateListener, job.isEnabled(), profilesToCollect); } - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) - || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) - || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) - || profilesToCollect.contains(DetectorProfileName.MODELS)) { - profileModels(detector, profilesToCollect, job, false, delegateListener); + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS)) { + profileModels(config, profilesToCollect, job, false, delegateListener); } - if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { - adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, null, delegateListener); + if (profilesToCollect.contains(ProfileName.AD_TASK)) { + getLatestHistoricalTaskProfile(configId, transportService, null, delegateListener); } - } - - } catch (Exception e) { - logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG, e); - listener.onFailure(e); - } - } else { - onGetDetectorForPrepare(detectorId, listener, profilesToCollect); - } - }, exception -> { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - logger.info(exception.getMessage()); - onGetDetectorForPrepare(detectorId, listener, profilesToCollect); - } else { - logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG + detectorId); - listener.onFailure(exception); - } - })); - } - private void profileEntityStats(MultiResponsesDelegateActionListener listener, AnomalyDetector detector) { - List categoryField = detector.getCategoryFields(); - if (!detector.isHighCardinality() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) { - listener.onResponse(new DetectorProfile.Builder().build()); - } else { - if (categoryField.size() == 1) { - // Run a cardinality aggregation to count the cardinality 
of single category fields - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - CardinalityAggregationBuilder aggBuilder = new CardinalityAggregationBuilder(ADCommonName.TOTAL_ENTITIES); - aggBuilder.field(categoryField.get(0)); - searchSourceBuilder.aggregation(aggBuilder); - - SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { - Map aggMap = searchResponse.getAggregations().asMap(); - InternalCardinality totalEntities = (InternalCardinality) aggMap.get(ADCommonName.TOTAL_ENTITIES); - long value = totalEntities.getValue(); - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - DetectorProfile profile = profileBuilder.totalEntities(value).build(); - listener.onResponse(profile); - }, searchException -> { - logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId()); - listener.onFailure(searchException); - }); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - request, - client::search, - detector.getId(), - client, - AnalysisType.AD, - searchResponseListener - ); - } else { - // Run a composite query and count the number of buckets to decide cardinality of multiple category fields - AggregationBuilder bucketAggs = AggregationBuilders - .composite( - ADCommonName.TOTAL_ENTITIES, - detector - .getCategoryFields() - .stream() - .map(f -> new TermsValuesSourceBuilder(f).field(f)) - .collect(Collectors.toList()) - ) - .size(maxTotalEntitiesToTrack); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0); - SearchRequest searchRequest = new SearchRequest() - .indices(detector.getIndices().toArray(new String[0])) - .source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - Aggregations aggs = searchResponse.getAggregations(); - if (aggs == null) { - // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date - // with - // the large amounts of changes there). For example, they may change to if there are results return it; otherwise - // return - // null instead of an empty Aggregations as they currently do. 
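To make the two strategies above concrete: with category fields such as host and service, each composite bucket is one distinct (host, service) combination, so counting buckets counts entities. A condensed sketch (the field names host and service are illustrative; builders and constants as in the surrounding file, plus java.util.List):

// one category field: approximate distinct count via a cardinality metric (HyperLogLog++)
CardinalityAggregationBuilder single = new CardinalityAggregationBuilder(ADCommonName.TOTAL_ENTITIES).field("host");

// several category fields: request one page of composite buckets and count them client-side
List<CompositeValuesSourceBuilder<?>> sources = List
    .of(new TermsValuesSourceBuilder("host").field("host"), new TermsValuesSourceBuilder("service").field("service"));
CompositeAggregationBuilder multi = AggregationBuilders.composite(ADCommonName.TOTAL_ENTITIES, sources).size(maxTotalEntitiesToTrack);

One consequence worth noting: only the first composite page is ever counted (there is no after-key pagination here), so the reported entity total saturates at maxTotalEntitiesToTrack.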
- logger.warn("Unexpected null aggregation."); - listener.onResponse(profileBuilder.totalEntities(0L).build()); - return; + } catch (Exception e) { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + listener.onFailure(e); } - - Aggregation aggrResult = aggs.get(ADCommonName.TOTAL_ENTITIES); - if (aggrResult == null) { - listener.onFailure(new IllegalArgumentException("Fail to find valid aggregation result")); - return; - } - - CompositeAggregation compositeAgg = (CompositeAggregation) aggrResult; - DetectorProfile profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build(); - listener.onResponse(profile); - }, searchException -> { - logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId()); - listener.onFailure(searchException); - }); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - detector.getId(), - client, - AnalysisType.AD, - searchResponseListener - ); - } - - } - } - - private void onGetDetectorForPrepare(String detectorId, ActionListener listener, Set profiles) { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - if (profiles.contains(DetectorProfileName.STATE)) { - profileBuilder.state(DetectorState.DISABLED); - } - if (profiles.contains(DetectorProfileName.AD_TASK)) { - adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, profileBuilder.build(), listener); - } else { - listener.onResponse(profileBuilder.build()); + } else { + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + logger.info(exception.getMessage()); + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } else { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG + configId); + listener.onFailure(exception); + } + })); } } @@ -395,141 +204,29 @@ private void onGetDetectorForPrepare(String detectorId, ActionListener listener, boolean enabled, - Set profilesToCollect + Set profilesToCollect ) { if (enabled) { - RCFPollingRequest request = new RCFPollingRequest(detector.getId()); - client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener)); + RCFPollingRequest request = new RCFPollingRequest(config.getId()); + client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(config, profilesToCollect, listener)); } else { DetectorProfile.Builder builder = new DetectorProfile.Builder(); - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - builder.state(DetectorState.DISABLED); + if (profilesToCollect.contains(ProfileName.STATE)) { + builder.state(ConfigState.DISABLED); } listener.onResponse(builder.build()); } } - private void profileModels( - AnomalyDetector detector, - Set profiles, - Job job, - boolean forMultiEntityDetector, - MultiResponsesDelegateActionListener listener - ) { - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - ProfileRequest profileRequest = new ProfileRequest(detector.getId(), profiles, forMultiEntityDetector, dataNodes); - client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress - } - - private ActionListener onModelResponse( - AnomalyDetector detector, - Set profilesToCollect, - Job job, - MultiResponsesDelegateActionListener listener - ) { - boolean 
isMultientityDetector = detector.isHighCardinality(); - return ActionListener.wrap(profileResponse -> { - DetectorProfile.Builder profile = new DetectorProfile.Builder(); - if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE)) { - profile.coordinatingNode(profileResponse.getCoordinatingNode()); - } - if (profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE)) { - profile.shingleSize(profileResponse.getShingleSize()); - } - if (profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) { - profile.totalSizeInBytes(profileResponse.getTotalSizeInBytes()); - } - if (profilesToCollect.contains(DetectorProfileName.MODELS)) { - profile.modelProfile(profileResponse.getModelProfile()); - profile.modelCount(profileResponse.getModelCount()); - } - if (isMultientityDetector && profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES)) { - profile.activeEntities(profileResponse.getActiveEntities()); - } - - if (isMultientityDetector - && (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) - || profilesToCollect.contains(DetectorProfileName.STATE))) { - profileMultiEntityDetectorStateRelated(job, profilesToCollect, profileResponse, profile, detector, listener); - } else { - listener.onResponse(profile.build()); - } - }, listener::onFailure); - } - - private void profileMultiEntityDetectorStateRelated( - Job job, - Set profilesToCollect, - ProfileResponse profileResponse, - DetectorProfile.Builder profileBuilder, - AnomalyDetector detector, - MultiResponsesDelegateActionListener listener - ) { - if (job.isEnabled()) { - if (profileResponse.getTotalUpdates() < requiredSamples) { - // need to double check since what ProfileResponse returns is the highest priority entity currently in memory, but - // another entity might have already been initialized and sit somewhere else (in memory or on disk). 
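ProfileUtil.confirmDetectorRealtimeInitStatus, called just below, is not shown in this diff. The double check it performs boils down to searching the result index for any real score emitted after the job was enabled; a hit from any entity proves initialization finished at least once. A minimal sketch under that assumption (imports elided: SearchRequest, SearchSourceBuilder, QueryBuilders; the field names follow AnomalyResult's conventional mapping and may differ from the actual helper):

SearchSourceBuilder source = new SearchSourceBuilder()
    .query(QueryBuilders.boolQuery()
        .filter(QueryBuilders.termQuery("detector_id", detector.getId()))
        .filter(QueryBuilders.rangeQuery("execution_end_time").gte(enabledTime))
        .filter(QueryBuilders.rangeQuery("anomaly_score").gt(0)))
    .size(1); // existence is enough; no need to fetch more than one hit
SearchRequest confirmInit = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS).source(source);

The onInittedEver listener below then treats zero hits as still initializing and any hit as running.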
- long enabledTime = job.getEnabledTime().toEpochMilli(); - long totalUpdates = profileResponse.getTotalUpdates(); - ProfileUtil - .confirmDetectorRealtimeInitStatus( - detector, - enabledTime, - client, - onInittedEver(enabledTime, profileBuilder, profilesToCollect, detector, totalUpdates, listener) - ); - } else { - createRunningStateAndInitProgress(profilesToCollect, profileBuilder); - listener.onResponse(profileBuilder.build()); - } - } else { - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - profileBuilder.state(DetectorState.DISABLED); - } - listener.onResponse(profileBuilder.build()); - } - } - - private ActionListener onInittedEver( - long lastUpdateTimeMs, - DetectorProfile.Builder profileBuilder, - Set profilesToCollect, - AnomalyDetector detector, - long totalUpdates, - MultiResponsesDelegateActionListener listener - ) { - return ActionListener.wrap(searchResponse -> { - SearchHits hits = searchResponse.getHits(); - if (hits.getTotalHits().value == 0L) { - processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); - } else { - createRunningStateAndInitProgress(profilesToCollect, profileBuilder); - listener.onResponse(profileBuilder.build()); - } - }, exception -> { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - // anomaly result index is not created yet - processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); - } else { - logger - .error( - "Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}", - detector.getId() - ); - listener.onFailure(exception); - } - }); - } - /** * Listener for polling rcf updates through transport messaging * @param detector anomaly detector @@ -538,8 +235,8 @@ private ActionListener onInittedEver( * @return Listener for polling rcf updates through transport messaging */ private ActionListener onPollRCFUpdates( - AnomalyDetector detector, - Set profilesToCollect, + Config detector, + Set profilesToCollect, MultiResponsesDelegateActionListener listener ) { return ActionListener.wrap(rcfPollResponse -> { @@ -547,7 +244,7 @@ private ActionListener onPollRCFUpdates( if (totalUpdates < requiredSamples) { processInitResponse(detector, profilesToCollect, totalUpdates, false, new DetectorProfile.Builder(), listener); } else { - DetectorProfile.Builder builder = new DetectorProfile.Builder(); + DetectorProfile.Builder builder = createProfileBuilder(); createRunningStateAndInitProgress(profilesToCollect, builder); listener.onResponse(builder.build()); } @@ -570,7 +267,7 @@ private ActionListener onPollRCFUpdates( // a detector before cold start finishes, where the actual // initialization time may be much shorter if sufficient historical // data exists. 
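// Aside: the hideMinutesLeft flag passed to processInitResponse below maps to
// intervalMins = 0 in the init-progress math. A minimal sketch of what
// computeInitProgressProfile is expected to compute; the linear ramp toward
// requiredSamples and the (String, long, int) constructor are assumptions inferred
// from `new InitProgressProfile("100%", 0, 0)` in this file, not confirmed internals:
private InitProgressProfile initProgressSketch(long totalUpdates, long intervalMins) {
    // fraction of the required RCF updates observed so far, capped at 100%
    double fraction = Math.min(1.0, (double) totalUpdates / requiredSamples);
    int neededSamples = (int) Math.max(0, requiredSamples - totalUpdates);
    // one sample arrives per detector interval, so the remaining wait scales with
    // the interval length; intervalMins == 0 suppresses the estimate
    long minutesLeft = neededSamples * intervalMins;
    return new InitProgressProfile((int) (fraction * 100) + "%", minutesLeft, neededSamples);
}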
- processInitResponse(detector, profilesToCollect, 0L, true, new DetectorProfile.Builder(), listener); + processInitResponse(detector, profilesToCollect, 0L, true, createProfileBuilder(), listener); } else { logger.error(new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getId()), exception); listener.onFailure(exception); @@ -578,40 +275,4 @@ private ActionListener onPollRCFUpdates( }); } - private void createRunningStateAndInitProgress(Set profilesToCollect, DetectorProfile.Builder builder) { - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - builder.state(DetectorState.RUNNING).build(); - } - - if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { - InitProgressProfile initProgress = new InitProgressProfile("100%", 0, 0); - builder.initProgress(initProgress); - } - } - - private void processInitResponse( - AnomalyDetector detector, - Set profilesToCollect, - long totalUpdates, - boolean hideMinutesLeft, - DetectorProfile.Builder builder, - MultiResponsesDelegateActionListener listener - ) { - if (profilesToCollect.contains(DetectorProfileName.STATE)) { - builder.state(DetectorState.INIT); - } - - if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { - if (hideMinutesLeft) { - InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0); - builder.initProgress(initProgress); - } else { - long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes(); - InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins); - builder.initProgress(initProgress); - } - } - - listener.onResponse(builder.build()); - } } diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java index c5336316c..185c1c884 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java @@ -24,16 +24,16 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchSecurityException; -import org.opensearch.ad.constant.CommonValue; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.Features; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.EntityAnomalyResult; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.FeatureData; @@ -45,11 +45,11 @@ public final class AnomalyDetectorRunner { private final Logger logger = LogManager.getLogger(AnomalyDetectorRunner.class); - private final ModelManager modelManager; + private final ADModelManager modelManager; private final FeatureManager featureManager; private final int maxPreviewResults; - public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) { + public AnomalyDetectorRunner(ADModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) { 
this.modelManager = modelManager; this.featureManager = featureManager; this.maxPreviewResults = maxPreviewResults; @@ -166,24 +166,24 @@ private List parsePreviewResult( AnomalyResult result; if (results != null && results.size() > i) { - ThresholdingResult thresholdingResult = results.get(i); - List resultsToSave = thresholdingResult - .toIndexableResults( - detector, - Instant.ofEpochMilli(timeRange.getKey()), - Instant.ofEpochMilli(timeRange.getValue()), - null, - null, - featureDatas, - Optional.ofNullable(entity), - CommonValue.NO_SCHEMA_VERSION, - null, - null, - null + anomalyResults + .addAll( + results + .get(i) + .toIndexableResults( + detector, + Instant.ofEpochMilli(timeRange.getKey()), + Instant.ofEpochMilli(timeRange.getValue()), + null, + null, + featureDatas, + Optional.ofNullable(entity), + CommonValue.NO_SCHEMA_VERSION, + null, + null, + null + ) ); - for (AnomalyResult r : resultsToSave) { - anomalyResults.add(r); - } } else { result = new AnomalyResult( detector.getId(), diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java index 0f8c6fca4..e1d042267 100644 --- a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java +++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java @@ -11,373 +11,121 @@ package org.opensearch.ad; -import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; - import java.time.Instant; import java.util.ArrayList; -import java.util.HashSet; import java.util.Optional; -import java.util.Set; -import java.util.concurrent.TimeUnit; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.update.UpdateResponse; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileRequest; -import org.opensearch.ad.transport.RCFPollingAction; -import org.opensearch.ad.transport.RCFPollingRequest; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.client.Client; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.search.SearchHits; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.model.FeatureData; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import 
org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.ExceptionUtil; -public class ExecuteADResultResponseRecorder { - private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class); +public class ExecuteADResultResponseRecorder extends + ExecuteResultResponseRecorder { - private ADIndexManagement anomalyDetectionIndices; - private AnomalyIndexHandler anomalyResultHandler; - private ADTaskManager adTaskManager; - private DiscoveryNodeFilterer nodeFilter; - private ThreadPool threadPool; - private Client client; - private NodeStateManager nodeStateManager; - private ADTaskCacheManager adTaskCacheManager; - private int rcfMinSamples; + private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class); public ExecuteADResultResponseRecorder( - ADIndexManagement anomalyDetectionIndices, - AnomalyIndexHandler anomalyResultHandler, - ADTaskManager adTaskManager, + ADIndexManagement indexManagement, + ResultBulkIndexingHandler resultHandler, + ADTaskManager taskManager, DiscoveryNodeFilterer nodeFilter, ThreadPool threadPool, Client client, NodeStateManager nodeStateManager, - ADTaskCacheManager adTaskCacheManager, + ADTaskCacheManager taskCacheManager, int rcfMinSamples ) { - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.anomalyResultHandler = anomalyResultHandler; - this.adTaskManager = adTaskManager; - this.nodeFilter = nodeFilter; - this.threadPool = threadPool; - this.client = client; - this.nodeStateManager = nodeStateManager; - this.adTaskCacheManager = adTaskCacheManager; - this.rcfMinSamples = rcfMinSamples; + super( + indexManagement, + resultHandler, + taskManager, + nodeFilter, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + client, + nodeStateManager, + taskCacheManager, + rcfMinSamples, + ADIndex.RESULT, + AnalysisType.AD, + ADProfileAction.INSTANCE + ); } - public void indexAnomalyResult( - Instant detectionStartTime, - Instant executionStartTime, - AnomalyResultResponse response, - AnomalyDetector detector + @Override + protected AnomalyResult createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user ) { - String detectorId = detector.getId(); - try { - // skipping writing to the result index if not necessary - // For a single-entity detector, the result is not useful if error is null - // and rcf score (thus anomaly grade/confidence) is null. - // For a HCAD detector, we don't need to save on the detector level. - // We return 0 or Double.NaN rcf score if there is no error. 
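// Concrete reading of the rationale above: a single-stream run that returns
// rcfScore = 0.0 (or Double.NaN) together with a null error skips result indexing
// entirely and only refreshes the realtime task, which is what the check below does.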
- if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) { - updateRealtimeTask(response, detectorId); - return; - } - IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay(); - Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - User user = detector.getUser(); - - if (response.getError() != null) { - log.info("Anomaly result action run successfully for {} with error {}", detectorId, response.getError()); - } - - AnomalyResult anomalyResult = response - .toAnomalyResult( - detectorId, - dataStartTime, - dataEndTime, - executionStartTime, - Instant.now(), - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), - user, - response.getError() - ); - - String resultIndex = detector.getCustomResultIndex(); - anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); - updateRealtimeTask(response, detectorId); - } catch (EndRunException e) { - throw e; - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); - } + return new AnomalyResult( + configId, + null, // no task id + new ArrayList(), + dataStartTime, + dataEndTime, + executeEndTime, + Instant.now(), + errorMessage, + Optional.empty(), // single-stream detectors have no entity + user, + indexManagement.getSchemaVersion(resultIndex), + null // no model id + ); } /** * Update real time task (one document per detector in state index). If the real-time task has no changes compared with local cache, - * the task won't update. Task only updates when the state changed, or any error happened, or AD job stopped. Task is mainly consumed - * by the front-end to track detector status. For single-stream detectors, we embed model total updates in AnomalyResultResponse and - * update state accordingly. For HCAD, we won't wait for model finishing updating before returning a response to the job scheduler + * the task won't update. Task only updates when the state changed, or any error happened, or job stopped. Task is mainly consumed + * by the front-end to track analysis status. For single-stream analyses, we embed model total updates in ResultResponse and + * update state accordingly. For HC analysis, we won't wait for model finishing updating before returning a response to the job scheduler * since it might be long before all entities finish execution. So we don't embed model total updates in AnomalyResultResponse. * Instead, we issue a profile request to poll each model node and get the maximum total updates among all models. 
* @param response response returned from executing AnomalyResultAction - * @param detectorId Detector Id + * @param configId config Id */ - private void updateRealtimeTask(AnomalyResultResponse response, String detectorId) { - if (response.isHCDetector() != null && response.isHCDetector()) { - if (adTaskManager.skipUpdateHCRealtimeTask(detectorId, response.getError())) { + @Override + protected void updateRealtimeTask(ResultResponse response, String configId) { + if (response.isHC() != null && response.isHC()) { + if (taskManager.skipUpdateRealtimeTask(configId, response.getError())) { return; } - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - Set profiles = new HashSet<>(); - profiles.add(DetectorProfileName.INIT_PROGRESS); - ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, true, dataNodes); - Runnable profileHCInitProgress = () -> { - client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> { - log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates()); - updateLatestRealtimeTask(detectorId, null, r.getTotalUpdates(), response.getIntervalInMinutes(), response.getError()); - }, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); })); - }; - if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) { - // real time init progress is 0 may mean this is a newly started detector - // Delay real time cache update by one minute. If we are in init status, the delay may give the model training time to - // finish. We can change the detector running immediately instead of waiting for the next interval. - threadPool - .schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); - } else { - profileHCInitProgress.run(); - } - + delayedUpdate(response, configId); } else { log .debug( "Update latest realtime task for single stream detector {}, total updates: {}", - detectorId, + configId, response.getRcfTotalUpdates() ); - updateLatestRealtimeTask(detectorId, null, response.getRcfTotalUpdates(), response.getIntervalInMinutes(), response.getError()); - } - } - - private void updateLatestRealtimeTask( - String detectorId, - String taskState, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error - ) { - // Don't need info as this will be printed repeatedly in each interval - ActionListener listener = ActionListener.wrap(r -> { - if (r != null) { - log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState); - } - }, e -> { - if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) { - // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction. - log.error("Can't find latest realtime task of detector " + detectorId); - adTaskManager.removeRealtimeTaskCache(detectorId); - } else { - log.error("Failed to update latest realtime task for detector " + detectorId, e); - } - }); - - // rcfTotalUpdates is null when we save exception messages - if (!adTaskCacheManager.hasQueriedResultIndex(detectorId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { - // confirm the total updates number since it is possible that we have already had results after job enabling time - // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. 
- confirmTotalRCFUpdatesFound( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - ActionListener - .wrap( - r -> adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - r, - detectorIntervalInMinutes, - error, - listener - ), - e -> { - log.error("Fail to confirm rcf update", e); - adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - listener - ); - } - ) - ); - } else { - adTaskManager - .updateLatestRealtimeTaskOnCoordinatingNode( - detectorId, - taskState, - rcfTotalUpdates, - detectorIntervalInMinutes, - error, - listener - ); - } - } - - /** - * The function is not only indexing the result with the exception, but also updating the task state after - * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector. - * - * @param detectionStartTime execution start time - * @param executionStartTime execution end time - * @param errorMessage Error message to record - * @param taskState AD task state (e.g., stopped) - * @param detector Detector config accessor - */ - public void indexAnomalyResultException( - Instant detectionStartTime, - Instant executionStartTime, - String errorMessage, - String taskState, - AnomalyDetector detector - ) { - String detectorId = detector.getId(); - try { - IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay(); - Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); - User user = detector.getUser(); - - AnomalyResult anomalyResult = new AnomalyResult( - detectorId, - null, // no task id - new ArrayList(), - dataStartTime, - dataEndTime, - executionStartTime, - Instant.now(), - errorMessage, - Optional.empty(), // single-stream detectors have no entity - user, - anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), - null // no model id + updateLatestRealtimeTask( + configId, + null, + response.getRcfTotalUpdates(), + response.getConfigIntervalInMinutes(), + response.getError() ); - String resultIndex = detector.getCustomResultIndex(); - if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) { - // Set result index as null, will write exception to default result index. - anomalyResultHandler.index(anomalyResult, detectorId, null); - } else { - anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); - } - - if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isHighCardinality()) { - // single stream detector raises ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG - // when there is no checkpoint. - // Delay real time cache update by one minute so we will have trained models by then and update the state - // document accordingly. - threadPool.schedule(() -> { - RCFPollingRequest request = new RCFPollingRequest(detectorId); - client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> { - long totalUpdates = rcfPollResponse.getTotalUpdates(); - // if there are updates, don't record failures - updateLatestRealtimeTask( - detectorId, - taskState, - totalUpdates, - detector.getIntervalInMinutes(), - totalUpdates > 0 ? 
"" : errorMessage - ); - }, e -> { - log.error("Fail to execute RCFRollingAction", e); - updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); - })); - }, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); - } else { - updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); - } - - } catch (Exception e) { - log.error("Failed to index anomaly result for " + detectorId, e); } } - - private void confirmTotalRCFUpdatesFound( - String detectorId, - String taskState, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error, - ActionListener listener - ) { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector")); - return; - } - nodeStateManager.getJob(detectorId, ActionListener.wrap(jobOptional -> { - if (!jobOptional.isPresent()) { - listener.onFailure(new TimeSeriesException(detectorId, "fail to get job")); - return; - } - - ProfileUtil - .confirmDetectorRealtimeInitStatus( - (AnomalyDetector) detectorOptional.get(), - jobOptional.get().getEnabledTime().toEpochMilli(), - client, - ActionListener.wrap(searchResponse -> { - ActionListener.completeWith(listener, () -> { - SearchHits hits = searchResponse.getHits(); - Long correctedTotalUpdates = rcfTotalUpdates; - if (hits.getTotalHits().value > 0L) { - // correct the number if we have already had results after job enabling time - // so that the detector won't stay initialized - correctedTotalUpdates = Long.valueOf(rcfMinSamples); - } - adTaskCacheManager.markResultIndexQueried(detectorId); - return correctedTotalUpdates; - }); - }, exception -> { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - // anomaly result index is not created yet - adTaskCacheManager.markResultIndexQueried(detectorId); - listener.onResponse(0L); - } else { - listener.onFailure(exception); - } - }) - ); - }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get job")))); - }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector")))); - } } diff --git a/src/main/java/org/opensearch/ad/ProfileUtil.java b/src/main/java/org/opensearch/ad/ProfileUtil.java deleted file mode 100644 index 3d77924d0..000000000 --- a/src/main/java/org/opensearch/ad/ProfileUtil.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.client.Client; -import org.opensearch.core.action.ActionListener; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.ExistsQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.timeseries.constant.CommonName; - -public class ProfileUtil { - /** - * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. 
- * Note this function is only meant to check for status of real time analysis. - * - * @param detectorId detector id - * @param enabledTime the time when AD job is enabled in milliseconds - * @return the search request - */ - private static SearchRequest createRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { - BoolQueryBuilder filterQuery = new BoolQueryBuilder(); - filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); - filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); - filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); - // Historical analysis result also stored in result index, which has non-null task_id. - // For realtime detection result, we should filter task_id == null - ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); - filterQuery.mustNot(taskIdExistsFilter); - - SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); - - SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); - request.source(source); - if (resultIndex != null) { - request.indices(resultIndex); - } - return request; - } - - public static void confirmDetectorRealtimeInitStatus( - AnomalyDetector detector, - long enabledTime, - Client client, - ActionListener listener - ) { - SearchRequest searchLatestResult = createRealtimeInittedEverRequest(detector.getId(), enabledTime, detector.getCustomResultIndex()); - client.search(searchLatestResult, listener); - } -} diff --git a/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java b/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java new file mode 100644 index 000000000..828146516 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.CacheBuffer; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * We use a layered cache to manage active entities’ states. We have a two-level + * cache that stores active entity states in each node. Each detector has its + * dedicated cache that stores ten (dynamically adjustable) entities’ states per + * node. A detector’s hottest entities load their states in the dedicated cache. + * If less than 10 entities use the dedicated cache, the secondary cache can use + * the rest of the free memory available to AD. The secondary cache is a shared + * memory among all detectors for the long tail. The shared cache size is 10% + * heap minus all of the dedicated cache consumed by single-entity and multi-entity + * detectors. 
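(Illustrative arithmetic, numbers not from this patch: with a 1 GB heap the AD model budget is 10%, about 100 MB; if dedicated caches across all detectors pin 60 MB of entity states, the shared cache for the long tail gets the remaining ~40 MB.)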
The shared cache’s size shrinks as the dedicated cache is filled
+ * up or more detectors are started.
+ *
+ * Implementation-wise, both dedicated cache and shared cache are stored in items
+ * and minimumCapacity controls the boundary. If the number of items is equal to or
+ * smaller than minimumCapacity, all of items belong to the dedicated cache; otherwise,
+ * the top minimumCapacity active entities (the last X entities in priorityList) are in
+ * the dedicated cache and all others are in the shared cache.
+ */
+public class ADCacheBuffer extends
+    CacheBuffer<ThresholdedRandomCutForest, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADCheckpointMaintainWorker> {
+
+    public ADCacheBuffer(
+        int minimumCapacity,
+        Clock clock,
+        MemoryTracker memoryTracker,
+        int checkpointIntervalHrs,
+        Duration modelTtl,
+        long memoryConsumptionPerEntity,
+        ADCheckpointWriteWorker checkpointWriteQueue,
+        ADCheckpointMaintainWorker checkpointMaintainQueue,
+        String configId,
+        long intervalSecs
+    ) {
+        super(
+            minimumCapacity,
+            clock,
+            memoryTracker,
+            checkpointIntervalHrs,
+            modelTtl,
+            memoryConsumptionPerEntity,
+            checkpointWriteQueue,
+            checkpointMaintainQueue,
+            configId,
+            intervalSecs,
+            Origin.REAL_TIME_DETECTOR
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java b/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java
new file mode 100644
index 000000000..e71c89962
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad.caching;
+
+import org.opensearch.timeseries.caching.CacheProvider;
+
+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
+
+/**
+ * Allows Guice dependency injection based on concrete types. Otherwise, Guice
+ * cannot decide which instance to inject based on generic types of CacheProvider.
+ */
+public class ADCacheProvider extends CacheProvider<ThresholdedRandomCutForest, ADPriorityCache> {
+
+}
diff --git a/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java b/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java
new file mode 100644
index 000000000..5bda770be
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java
@@ -0,0 +1,118 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.ad.caching; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Optional; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.PriorityCache; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADPriorityCache extends + PriorityCache { + private ADCheckpointWriteWorker checkpointWriteQueue; + private ADCheckpointMaintainWorker checkpointMaintainQueue; + + public ADPriorityCache( + ADCheckpointDao checkpointDao, + int hcDedicatedCacheSize, + Setting checkpointTtl, + int maxInactiveStates, + MemoryTracker memoryTracker, + int numberOfTrees, + Clock clock, + ClusterService clusterService, + Duration modelTtl, + ThreadPool threadPool, + int maintenanceFreqConstant, + Settings settings, + Setting checkpointSavingFreq, + ADCheckpointWriteWorker checkpointWriteQueue, + ADCheckpointMaintainWorker checkpointMaintainQueue + ) { + super( + checkpointDao, + hcDedicatedCacheSize, + checkpointTtl, + maxInactiveStates, + memoryTracker, + numberOfTrees, + clock, + clusterService, + modelTtl, + threadPool, + AD_THREAD_POOL_NAME, + maintenanceFreqConstant, + settings, + checkpointSavingFreq, + Origin.REAL_TIME_DETECTOR, + AD_DEDICATED_CACHE_SIZE, + AD_MODEL_MAX_SIZE_PERCENTAGE + ); + + this.checkpointWriteQueue = checkpointWriteQueue; + this.checkpointMaintainQueue = checkpointMaintainQueue; + } + + @Override + protected ADCacheBuffer createEmptyCacheBuffer(Config detector, long memoryConsumptionPerEntity) { + return new ADCacheBuffer( + detector.isHighCardinality() ? 
hcDedicatedCacheSize : 1, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + memoryConsumptionPerEntity, + checkpointWriteQueue, + checkpointMaintainQueue, + detector.getId(), + detector.getIntervalInSeconds() + ); + } + + @Override + protected ModelState createEmptyModelState(String modelId, String detectorId) { + return new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + 0, + null, + Optional.empty(), + new ArrayDeque<>() + ); + } +} diff --git a/src/main/java/org/opensearch/ad/caching/CacheProvider.java b/src/main/java/org/opensearch/ad/caching/CacheProvider.java deleted file mode 100644 index ab8fd191c..000000000 --- a/src/main/java/org/opensearch/ad/caching/CacheProvider.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.caching; - -import org.opensearch.common.inject.Provider; - -/** - * A wrapper to call concrete implementation of caching. Used in transport - * action. Don't use interface because transport action handler constructor - * requires a concrete class as input. - * - */ -public class CacheProvider implements Provider { - private EntityCache cache; - - public CacheProvider() { - - } - - @Override - public EntityCache get() { - return cache; - } - - public void set(EntityCache cache) { - this.cache = cache; - } -} diff --git a/src/main/java/org/opensearch/ad/caching/EntityCache.java b/src/main/java/org/opensearch/ad/caching/EntityCache.java deleted file mode 100644 index 287994efd..000000000 --- a/src/main/java/org/opensearch/ad/caching/EntityCache.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.caching; - -import java.util.Collection; -import java.util.List; -import java.util.Optional; - -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.ad.DetectorModelSize; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.timeseries.CleanState; -import org.opensearch.timeseries.MaintenanceState; -import org.opensearch.timeseries.model.Entity; - -public interface EntityCache extends MaintenanceState, CleanState, DetectorModelSize { - /** - * Get the ModelState associated with the entity. May or may not load the - * ModelState depending on the underlying cache's eviction policy. 
- * - * @param modelId Model Id - * @param detector Detector config object - * @return the ModelState associated with the model or null if no cached item - * for the entity - */ - ModelState get(String modelId, AnomalyDetector detector); - - /** - * Get the number of active entities of a detector - * @param detector Detector Id - * @return The number of active entities - */ - int getActiveEntities(String detector); - - /** - * - * @return total active entities in the cache - */ - int getTotalActiveEntities(); - - /** - * Whether an entity is active or not - * @param detectorId The Id of the detector that an entity belongs to - * @param entityModelId Entity model Id - * @return Whether an entity is active or not - */ - boolean isActive(String detectorId, String entityModelId); - - /** - * Get total updates of detector's most active entity's RCF model. - * - * @param detectorId detector id - * @return RCF model total updates of most active entity. - */ - long getTotalUpdates(String detectorId); - - /** - * Get RCF model total updates of specific entity - * - * @param detectorId detector id - * @param entityModelId entity model id - * @return RCF model total updates of specific entity. - */ - long getTotalUpdates(String detectorId, String entityModelId); - - /** - * Gets modelStates of all model hosted on a node - * - * @return list of modelStates - */ - List> getAllModels(); - - /** - * Return when the last active time of an entity's state. - * - * If the entity's state is active in the cache, the value indicates when the cache - * is lastly accessed (get/put). If the entity's state is inactive in the cache, - * the value indicates when the cache state is created or when the entity is evicted - * from active entity cache. - * - * @param detectorId The Id of the detector that an entity belongs to - * @param entityModelId Entity's Model Id - * @return if the entity is in the cache, return the timestamp in epoch - * milliseconds when the entity's state is lastly used. Otherwise, return -1. - */ - long getLastActiveMs(String detectorId, String entityModelId); - - /** - * Release memory when memory circuit breaker is open - */ - void releaseMemoryForOpenCircuitBreaker(); - - /** - * Select candidate entities for which we can load models - * @param cacheMissEntities Cache miss entities - * @param detectorId Detector Id - * @param detector Detector object - * @return A list of entities that are admitted into the cache as a result of the - * update and the left-over entities - */ - Pair, List> selectUpdateCandidate( - Collection cacheMissEntities, - String detectorId, - AnomalyDetector detector - ); - - /** - * - * @param detector Detector config - * @param toUpdate Model state candidate - * @return if we can host the given model state - */ - boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate); - - /** - * - * @param detectorId Detector Id - * @return a detector's model information - */ - List getAllModelProfile(String detectorId); - - /** - * Gets an entity's model sizes - * - * @param detectorId Detector Id - * @param entityModelId Entity's model Id - * @return the entity's memory size - */ - Optional getModelProfile(String detectorId, String entityModelId); - - /** - * Get a model state without incurring priority update. Used in maintenance. - * @param detectorId Detector Id - * @param modelId Model Id - * @return Model state - */ - Optional> getForMaintainance(String detectorId, String modelId); - - /** - * Remove entity model from active entity buffer and delete checkpoint. 
Used to clean corrupted model. - * @param detectorId Detector Id - * @param entityModelId Model Id - */ - void removeEntityModel(String detectorId, String entityModelId); -} diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ADCheckpointIndexRetention.java b/src/main/java/org/opensearch/ad/cluster/diskcleanup/ADCheckpointIndexRetention.java new file mode 100644 index 000000000..6cf8c2385 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/diskcleanup/ADCheckpointIndexRetention.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.cluster.diskcleanup; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; + +public class ADCheckpointIndexRetention extends BaseModelCheckpointIndexRetention { + + public ADCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + super(defaultCheckpointTtl, clock, indexCleanup, ADCommonName.CHECKPOINT_INDEX_NAME); + } + +} diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java index 31adb1dac..d782c5e4c 100644 --- a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java @@ -21,17 +21,11 @@ public class ADCommonMessages { public static final String FEATURE_NOT_AVAILABLE_ERR_MSG = "No Feature in current detection window."; public static final String DISABLED_ERR_MSG = "AD functionality is disabled. To enable update plugins.anomaly_detection.enabled to true"; - public static String FAIL_TO_PARSE_DETECTOR_MSG = "Fail to parse detector with id: "; - public static String FAIL_TO_GET_PROFILE_MSG = "Fail to get profile for detector "; - public static String FAIL_TO_GET_TOTAL_ENTITIES = "Failed to get total entities for detector "; public static String CATEGORICAL_FIELD_NUMBER_SURPASSED = "We don't support categorical fields more than "; - public static String EMPTY_PROFILES_COLLECT = "profiles to collect are missing or invalid"; - public static String FAIL_FETCH_ERR_MSG = "Fail to fetch profile for "; public static String DETECTOR_IS_RUNNING = "Detector is already running"; public static String DETECTOR_MISSING = "Detector is missing"; public static String AD_TASK_ACTION_MISSING = "AD task action is missing"; public static final String INDEX_NOT_FOUND = "index does not exist"; - public static final String NOT_EXISTENT_VALIDATION_TYPE = "The given validation type doesn't exist"; public static final String UNSUPPORTED_PROFILE_TYPE = "Unsupported profile types"; public static final String REQUEST_THROTTLED_MSG = "Request throttled. 
Please try again later."; @@ -41,13 +35,11 @@ public class ADCommonMessages { public static String EXCEED_HISTORICAL_ANALYSIS_LIMIT = "Exceed max historical analysis limit per node"; public static String NO_ELIGIBLE_NODE_TO_RUN_DETECTOR = "No eligible node to run detector "; public static String EMPTY_STALE_RUNNING_ENTITIES = "Empty stale running entities"; - public static String CAN_NOT_FIND_LATEST_TASK = "can't find latest task"; public static String NO_ENTITY_FOUND = "No entity found"; public static String HISTORICAL_ANALYSIS_CANCELLED = "Historical analysis cancelled by user"; public static String HC_DETECTOR_TASK_IS_UPDATING = "HC detector task is updating"; public static String INVALID_TIME_CONFIGURATION_UNITS = "Time unit %s is not supported"; public static String FAIL_TO_GET_DETECTOR = "Fail to get detector"; - public static String FAIL_TO_GET_DETECTOR_INFO = "Fail to get detector info"; public static String FAIL_TO_CREATE_DETECTOR = "Fail to create detector"; public static String FAIL_TO_UPDATE_DETECTOR = "Fail to update detector"; public static String FAIL_TO_PREVIEW_DETECTOR = "Fail to preview detector"; @@ -55,27 +47,6 @@ public class ADCommonMessages { public static String FAIL_TO_STOP_DETECTOR = "Fail to stop detector"; public static String FAIL_TO_DELETE_DETECTOR = "Fail to delete detector"; public static String FAIL_TO_DELETE_AD_RESULT = "Fail to delete anomaly result"; - public static String FAIL_TO_GET_STATS = "Fail to get stats"; - public static String FAIL_TO_SEARCH = "Fail to search"; - - public static String WINDOW_DELAY_REC = - "Latest seen data point is at least %d minutes ago, consider changing window delay to at least %d minutes."; - public static String TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA = - "There isn't enough historical data found with current timefield selected."; - public static String DETECTOR_INTERVAL_REC = - "The selected detector interval might collect sparse data. Consider changing interval length to: "; - public static String RAW_DATA_TOO_SPARSE = - "Source index data is potentially too sparse for model training. Consider changing interval length or ingesting more data"; - public static String MODEL_VALIDATION_FAILED_UNEXPECTEDLY = "Model validation experienced issues completing."; - public static String FILTER_QUERY_TOO_SPARSE = "Data is too sparse after data filter is applied. Consider changing the data filter"; - public static String CATEGORY_FIELD_TOO_SPARSE = - "Data is most likely too sparse with the given category fields. Consider revising category field/s or ingesting more data "; - public static String CATEGORY_FIELD_NO_DATA = - "No entity was found with the given categorical fields. Consider revising category field/s or ingesting more data"; - public static String FEATURE_QUERY_TOO_SPARSE = - "Data is most likely too sparse when given feature queries are applied. 
Consider revising feature queries."; - public static String TIMEOUT_ON_INTERVAL_REC = "Timed out getting interval recommendation"; - public static final String NO_MODEL_ERR_MSG = "No RCF models are available either because RCF" + " models are not ready or all nodes are unresponsive or the system might have bugs."; public static String INVALID_RESULT_INDEX_PREFIX = "Result index must start with " + CUSTOM_RESULT_INDEX_PREFIX; diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonName.java b/src/main/java/org/opensearch/ad/constant/ADCommonName.java index 3a97db889..260d162f1 100644 --- a/src/main/java/org/opensearch/ad/constant/ADCommonName.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonName.java @@ -11,8 +11,6 @@ package org.opensearch.ad.constant; -import org.opensearch.timeseries.stats.StatNames; - public class ADCommonName { // ====================================== // Index name @@ -25,46 +23,11 @@ public class ADCommonName { // The alias of the index in which to write AD result history public static final String ANOMALY_RESULT_INDEX_ALIAS = ".opendistro-anomaly-results"; - // ====================================== - // Format name - // ====================================== - public static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; - // ====================================== // Anomaly Detector name for X-Opaque-Id header // ====================================== public static final String ANOMALY_DETECTOR = "[Anomaly Detector]"; - // ====================================== - // Ultrawarm node attributes - // ====================================== - - // hot node - public static String HOT_BOX_TYPE = "hot"; - - // warm node - public static String WARM_BOX_TYPE = "warm"; - - // box type - public static final String BOX_TYPE_KEY = "box_type"; - - // ====================================== - // Profile name - // ====================================== - public static final String STATE = "state"; - public static final String ERROR = "error"; - public static final String COORDINATING_NODE = "coordinating_node"; - public static final String SHINGLE_SIZE = "shingle_size"; - public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; - public static final String MODELS = "models"; - public static final String MODEL = "model"; - public static final String INIT_PROGRESS = "init_progress"; - public static final String CATEGORICAL_FIELD = "category_field"; - public static final String TOTAL_ENTITIES = "total_entities"; - public static final String ACTIVE_ENTITIES = "active_entities"; - public static final String ENTITY_INFO = "entity_info"; - public static final String TOTAL_UPDATES = "total_updates"; - public static final String MODEL_COUNT = StatNames.MODEL_COUNT.getName(); // ====================================== // Historical detectors // ====================================== @@ -87,11 +50,8 @@ public class ADCommonName { public static final String CONFIDENCE_JSON_KEY = "confidence"; public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; public static final String QUEUE_JSON_KEY = "queue"; - // ====================================== - // Used for backward-compatibility in messaging - // ====================================== - public static final String EMPTY_FIELD = ""; + // ====================================== // Validation // ====================================== // detector validation aspect diff --git a/src/main/java/org/opensearch/ad/constant/CommonValue.java b/src/main/java/org/opensearch/ad/constant/ADCommonValue.java similarity 
index 81% rename from src/main/java/org/opensearch/ad/constant/CommonValue.java rename to src/main/java/org/opensearch/ad/constant/ADCommonValue.java index f5d5b15eb..91b9f72f7 100644 --- a/src/main/java/org/opensearch/ad/constant/CommonValue.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonValue.java @@ -11,9 +11,7 @@ package org.opensearch.ad.constant; -public class CommonValue { - // unknown or no schema version - public static Integer NO_SCHEMA_VERSION = 0; +public class ADCommonValue { public static String INTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/adinternal/"; public static String EXTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/ad/"; } diff --git a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java b/src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java similarity index 58% rename from src/main/java/org/opensearch/ad/ml/CheckpointDao.java rename to src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java index adb097cb6..a4b9cfb7d 100644 --- a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java +++ b/src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java @@ -15,42 +15,24 @@ import java.security.AccessController; import java.security.PrivilegedAction; import java.time.Clock; -import java.time.Duration; import java.time.Instant; import java.time.ZoneOffset; import java.time.ZonedDateTime; -import java.util.AbstractMap.SimpleImmutableEntry; -import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Arrays; import java.util.Base64; +import java.util.Deque; import java.util.HashMap; import java.util.List; -import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ExceptionsHelper; -import org.opensearch.ResourceAlreadyExistsException; -import org.opensearch.action.bulk.BulkAction; -import org.opensearch.action.bulk.BulkItemResponse; -import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.bulk.BulkResponse; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.get.MultiGetAction; -import org.opensearch.action.get.MultiGetRequest; -import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.action.update.UpdateRequest; -import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; @@ -58,13 +40,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.index.reindex.BulkByScrollResponse; -import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; -import org.opensearch.index.reindex.ScrollableHitSource; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelManager; +import 
org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.util.ClientUtil; @@ -88,29 +70,18 @@ /** * DAO for model checkpoints. */ -public class CheckpointDao { - - private static final Logger logger = LogManager.getLogger(CheckpointDao.class); - static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; - static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; - static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; - static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; - static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. Has nothing to do:"; - static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector"; +public class ADCheckpointDao extends CheckpointDao { + private static final Logger logger = LogManager.getLogger(ADCheckpointDao.class); + // ====================================== + // Model serialization/deserialization + // ====================================== public static final String ENTITY_RCF = "rcf"; public static final String ENTITY_THRESHOLD = "th"; public static final String ENTITY_TRCF = "trcf"; public static final String FIELD_MODELV2 = "modelV2"; public static final String DETECTOR_ID = "detectorId"; - // dependencies - private final Client client; - private final ClientUtil clientUtil; - - // configuration - private final String indexName; - private Gson gson; private RandomCutForestMapper mapper; @@ -129,11 +100,7 @@ public class CheckpointDao { private final ADIndexManagement indexUtil; private final JsonParser parser = new JsonParser(); - // we won't read/write a checkpoint larger than a threshold - private final int maxCheckpointBytes; - private final GenericObjectPool serializeRCFBufferPool; - private final int serializeRCFBufferSize; // anomaly rate private double anomalyRate; @@ -142,7 +109,6 @@ public class CheckpointDao { * * @param client ES search client * @param clientUtil utility with ES client - * @param indexName name of the index for model checkpoints * @param gson accessor to Gson functionality * @param mapper RCF model serialization utility * @param converter converter from rcf v1 serde to protostuff based format @@ -155,10 +121,9 @@ public class CheckpointDao { * @param serializeRCFBufferSize the size of the buffer for RCF serialization * @param anomalyRate anomaly rate */ - public CheckpointDao( + public ADCheckpointDao( Client client, ClientUtil clientUtil, - String indexName, Gson gson, RandomCutForestMapper mapper, V1JsonToV3StateConverter converter, @@ -169,32 +134,29 @@ public CheckpointDao( int maxCheckpointBytes, GenericObjectPool serializeRCFBufferPool, int serializeRCFBufferSize, - double anomalyRate + double anomalyRate, + Clock clock ) { - this.client = client; - this.clientUtil = clientUtil; - this.indexName = indexName; - this.gson = gson; + super( + client, + clientUtil, + ADCommonName.CHECKPOINT_INDEX_NAME, + gson, + maxCheckpointBytes, + serializeRCFBufferPool, + serializeRCFBufferSize, + indexUtil, + clock + ); this.mapper = mapper; this.converter = converter; this.trcfMapper = trcfMapper; this.trcfSchema = trcfSchema; this.thresholdingModelClass = thresholdingModelClass; this.indexUtil = indexUtil; - this.maxCheckpointBytes = maxCheckpointBytes; - this.serializeRCFBufferPool = 
serializeRCFBufferPool; - this.serializeRCFBufferSize = serializeRCFBufferSize; this.anomalyRate = anomalyRate; } - private void putModelCheckpoint(String modelId, Map source, ActionListener listener) { - if (indexUtil.doesCheckpointIndexExist()) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - onCheckpointNotExist(source, modelId, listener); - } - } - /** * Puts a rcf model checkpoint in the storage. * @@ -229,58 +191,25 @@ public void putThresholdCheckpoint(String modelId, ThresholdingModel threshold, putModelCheckpoint(modelId, source, listener); } - private void onCheckpointNotExist(Map source, String modelId, ActionListener listener) { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - saveModelCheckpointAsync(source, modelId, listener); - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); - } - })); - } - - /** - * Update the model doc using fields in source. This ensures we won't touch - * the old checkpoint and nodes with old/new logic can coexist in a cluster. - * This is useful for introducing compact rcf new model format. - * - * @param source fields to update - * @param modelId model Id, used as doc id in the checkpoint index - * @param listener Listener to return response - */ - private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { - - UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); - updateRequest.doc(source); - // If the document does not already exist, the contents of the upsert element are inserted as a new document. 
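
For reference, the update-or-insert idiom that the surrounding comments describe, as a minimal self-contained sketch (the class and method names here are hypothetical; the snippet is illustrative, not part of this change set):

import java.util.Map;

import org.opensearch.action.update.UpdateRequest;

// Minimal sketch of the docAsUpsert idiom: one request inserts `source` as a
// brand-new document when the id is absent, and merges `source` into the
// existing document when it is present, leaving other fields untouched.
final class CheckpointUpsertSketch {
    static UpdateRequest upsertRequest(String indexName, String modelId, Map<String, Object> source) {
        return new UpdateRequest(indexName, modelId).doc(source).docAsUpsert(true);
    }
}
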
-        // If the document exists, update fields in the map
-        updateRequest.docAsUpsert(true);
-        clientUtil
-            .asyncRequest(
-                updateRequest,
-                client::update,
-                ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)
-            );
-    }
-
     /**
      * Prepare for index request using the contents of the given model state
      * @param modelState an entity model state
      * @return serialized JSON map or empty map if the state is too bloated
      * @throws IOException when serialization fails
      */
-    public Map toIndexSource(ModelState modelState) throws IOException {
+    @Override
+    public Map toIndexSource(ModelState modelState) throws IOException {
         String modelId = modelState.getModelId();
         Map source = new HashMap<>();
-        EntityModel model = modelState.getModel();
-        Optional serializedModel = toCheckpoint(model, modelId);
+
+        Object model = modelState.getModel();
+        if (modelState.getEntity().isEmpty()) {
+            throw new IllegalArgumentException("Expect model state to be an entity model");
+        }
+
+        ThresholdedRandomCutForest entityModel = (ThresholdedRandomCutForest) model;
+
+        Optional serializedModel = toCheckpoint(entityModel, modelId);
         if (!serializedModel.isPresent() || serializedModel.get().length() > maxCheckpointBytes) {
             logger
                 .warn(
@@ -292,13 +221,25 @@ public Map toIndexSource(ModelState modelState) thr
                 );
             return source;
         }
-        String detectorId = modelState.getId();
+        source.put(FIELD_MODELV2, serializedModel.get());
+
+        if (modelState.getSamples() != null && !(modelState.getSamples().isEmpty())) {
+            source.put(CommonName.ENTITY_SAMPLE_QUEUE, toCheckpoint(modelState.getSamples()).get());
+        }
+
+        // if there are no samples and no model, no need to index as the other information is metadata
+        if (!source.containsKey(CommonName.ENTITY_SAMPLE_QUEUE) && !source.containsKey(FIELD_MODELV2)) {
+            return source;
+        }
+
+        String detectorId = modelState.getConfigId();
         source.put(DETECTOR_ID, detectorId);
         // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value
-        source.put(FIELD_MODELV2, serializedModel.get());
+        source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC));
-        source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT));
-        Optional entity = model.getEntity();
+        source.put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT));
+
+        Optional entity = modelState.getEntity();
         if (entity.isPresent()) {
             source.put(CommonName.ENTITY_KEY, entity.get());
         }
@@ -312,7 +253,7 @@ public Map toIndexSource(ModelState modelState) thr
      * @param modelId model id
      * @return serialized string
      */
-    public Optional toCheckpoint(EntityModel model, String modelId) {
+    public Optional toCheckpoint(ThresholdedRandomCutForest model, String modelId) {
        return AccessController.doPrivileged((PrivilegedAction>) () -> {
            if (model == null) {
                logger.warn("Empty model");
@@ -320,11 +261,8 @@ public Optional toCheckpoint(EntityModel model, String modelId) {
            }
            try {
                JsonObject json = new JsonObject();
-                if (model.getSamples() != null && !(model.getSamples().isEmpty())) {
-                    json.add(CommonName.ENTITY_SAMPLE, gson.toJsonTree(model.getSamples()));
-                }
-                if (model.getTrcf().isPresent()) {
-                    json.addProperty(ENTITY_TRCF, toCheckpoint(model.getTrcf().get()));
+                if (model != null) {
+                    json.addProperty(ENTITY_TRCF, toCheckpoint(model));
                }
                // if json is empty, it will be an empty Json string {}. No need to save it on disk.
                return json.entrySet().isEmpty() ? 
Optional.empty() : Optional.ofNullable(gson.toJson(json)); @@ -335,7 +273,7 @@ public Optional toCheckpoint(EntityModel model, String modelId) { }); } - private String toCheckpoint(ThresholdedRandomCutForest trcf) { + String toCheckpoint(ThresholdedRandomCutForest trcf) { String checkpoint = null; Map.Entry result = checkoutOrNewBuffer(); LinkedBuffer buffer = result.getKey(); @@ -369,21 +307,6 @@ private String toCheckpoint(ThresholdedRandomCutForest trcf) { return checkpoint; } - private Map.Entry checkoutOrNewBuffer() { - LinkedBuffer buffer = null; - boolean isCheckout = true; - try { - buffer = serializeRCFBufferPool.borrowObject(); - } catch (Exception e) { - logger.warn("Failed to borrow a buffer from pool", e); - } - if (buffer == null) { - buffer = LinkedBuffer.allocate(serializeRCFBufferSize); - isCheckout = false; - } - return new SimpleImmutableEntry(buffer, isCheckout); - } - private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) { try { byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> { @@ -396,73 +319,6 @@ private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer } } - /** - * Deletes the model checkpoint for the model. - * - * @param modelId id of the model - * @param listener onReponse is called with null when the operation is completed - */ - public void deleteModelCheckpoint(String modelId, ActionListener listener) { - clientUtil - .asyncRequest( - new DeleteRequest(indexName, modelId), - client::delete, - ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) - ); - } - - /** - * Delete checkpoints associated with a detector. Used in multi-entity detector. - * @param detectorID Detector Id - */ - public void deleteModelCheckpointByDetectorId(String detectorID) { - // A bulk delete request is performed for each batch of matching documents. If a - // search or bulk request is rejected, the requests are retried up to 10 times, - // with exponential back off. If the maximum retry limit is reached, processing - // halts and all failed requests are returned in the response. Any delete - // requests that completed successfully still stick, they are not rolled back. - DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(ADCommonName.CHECKPOINT_INDEX_NAME) - .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorID)) - .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) - .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. - // Retry in this case - .setRequestsPerSecond(500); // throttle delete requests - logger.info("Delete checkpoints of detector {}", detectorID); - client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { - if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { - logFailure(response, detectorID); - } - // can return 0 docs get deleted because: - // 1) we cannot find matching docs - // 2) bad stats from OpenSearch. In this case, docs are deleted, but - // OpenSearch says deleted is 0. - logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); - }, exception -> { - if (exception instanceof IndexNotFoundException) { - logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorID); - } else { - // Gonna eventually delete in daily cron. 
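
For reference, the three partial-failure signals that the logFailure handler below inspects, as a standalone sketch (hypothetical class name; any of these outcomes means some checkpoints may survive until the daily cleanup cron):

import org.opensearch.index.reindex.BulkByScrollResponse;

// Sketch: a delete-by-query can partially fail in three ways.
final class DeleteByQueryFailureSketch {
    static boolean hasPartialFailure(BulkByScrollResponse response) {
        return response.isTimedOut()                    // timed out mid-scroll
            || !response.getBulkFailures().isEmpty()    // some delete batches rejected
            || !response.getSearchFailures().isEmpty(); // some scroll searches failed
    }
}
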
- logger.error(NOT_ABLE_TO_DELETE_LOG_MSG, exception); - } - })); - } - - private void logFailure(BulkByScrollResponse response, String detectorID) { - if (response.isTimedOut()) { - logger.warn(TIMEOUT_LOG_MSG + " {}", detectorID); - } else if (!response.getBulkFailures().isEmpty()) { - logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorID); - for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { - logger.warn(bulkFailure); - } - } else { - logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorID); - for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { - logger.warn(searchFailure); - } - } - } - /** * Load json checkpoint into models * @@ -471,9 +327,14 @@ private void logFailure(BulkByScrollResponse response, String detectorID) { * @return a pair of entity model and its last checkpoint time; or empty if * the raw checkpoint is too large */ - public Optional> fromEntityModelCheckpoint(Map checkpoint, String modelId) { + @Override + protected ModelState fromEntityModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ) { try { - return AccessController.doPrivileged((PrivilegedAction>>) () -> { + return AccessController.doPrivileged((PrivilegedAction>) () -> { Object modelObj = checkpoint.get(FIELD_MODELV2); if (modelObj == null) { // in case there is old -format checkpoint @@ -481,24 +342,14 @@ public Optional> fromEntityModelCheckpoint(Map maxCheckpointBytes) { logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length())); - return Optional.empty(); + return null; } JsonObject json = parser.parse(model).getAsJsonObject(); - ArrayDeque samples = null; - if (json.has(CommonName.ENTITY_SAMPLE)) { - // verified, don't need privileged call to get permission - samples = new ArrayDeque<>( - Arrays.asList(this.gson.fromJson(json.getAsJsonArray(CommonName.ENTITY_SAMPLE), new double[0][0].getClass())) - ); - } else { - // avoid possible null pointer exception - samples = new ArrayDeque<>(); - } ThresholdedRandomCutForest trcf = null; if (json.has(ENTITY_TRCF)) { @@ -518,15 +369,19 @@ public Optional> fromEntityModelCheckpoint(Map convertedTRCF = convertToTRCF(rcf, threshold); - // if checkpoint is corrupted (e.g., some unexpected checkpoint when we missed - // the mark in backward compatibility), we are not gonna load the model part - // the model will have to use live data to initialize - if (convertedTRCF.isPresent()) { - trcf = convertedTRCF.get(); + if (rcf.isPresent()) { + Optional convertedTRCF = convertToTRCF(rcf.get(), threshold); + // if checkpoint is corrupted (e.g., some unexpected checkpoint when we missed + // the mark in backward compatibility), we are not gonna load the model part + // the model will have to use live data to initialize + if (convertedTRCF.isPresent()) { + trcf = convertedTRCF.get(); + } } } + Deque sampleQueue = loadSampleQueue(checkpoint, modelId); + String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP)); Instant timestamp = Instant.parse(lastCheckpointTimeString); Entity entity = null; @@ -538,14 +393,27 @@ public Optional> fromEntityModelCheckpoint(Map(entityModel, timestamp)); + + ModelState modelState = new ModelState( + trcf, + modelId, + configId, + ModelManager.ModelType.TRCF.getName(), + clock, + 0, + // TODO: track last processed sample in AD + new Sample(), + Optional.ofNullable(entity), + sampleQueue + ); + modelState.setLastCheckpointTime(timestamp); + return modelState; }); } catch (Exception e) { logger.warn("Exception 
while deserializing checkpoint " + modelId, e); // checkpoint corrupted (e.g., a checkpoint not recognized by current code // due to bugs). Better redo training. - return Optional.empty(); + return null; } } @@ -604,7 +472,7 @@ private void deserializeTRCFModel( String thresholdingModelId = SingleStreamModelIdMapper.getThresholdModelIdFromRCFModelId(rcfModelId); // query for threshold model and combinne rcf and threshold model into a ThresholdedRandomCutForest getThresholdModel(thresholdingModelId, ActionListener.wrap(thresholdingModel -> { - listener.onResponse(convertToTRCF(forest, thresholdingModel)); + listener.onResponse(convertToTRCF(forest.get(), thresholdingModel)); }, listener::onFailure)); } } catch (Exception e) { @@ -616,30 +484,14 @@ private void deserializeTRCFModel( } } - /** - * Read a checkpoint from the index and return the EntityModel object - * @param modelId Model Id - * @param listener Listener to return a pair of entity model and its last checkpoint time - */ - public void deserializeModelCheckpoint(String modelId, ActionListener>> listener) { - clientUtil.asyncRequest(new GetRequest(indexName, modelId), client::get, ActionListener.wrap(response -> { - listener.onResponse(processGetResponse(response, modelId)); - }, listener::onFailure)); - } - - /** - * Process a checkpoint GetResponse and return the EntityModel object - * @param response Checkpoint Index GetResponse - * @param modelId Model Id - * @return a pair of entity model and its last checkpoint time - */ - public Optional> processGetResponse(GetResponse response, String modelId) { - Optional> checkpointString = processRawCheckpoint(response); - if (checkpointString.isPresent()) { - return fromEntityModelCheckpoint(checkpointString.get(), modelId); - } else { - return Optional.empty(); - } + @Override + protected ModelState fromSingleStreamModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ) { + // single stream AD code path is still using old way + throw new UnsupportedOperationException("This method is not supported"); } /** @@ -703,39 +555,8 @@ private Optional processThresholdModelCheckpoint(GetResponse response) { .map(source -> source.get(CommonName.FIELD_MODEL)); } - private Optional> processRawCheckpoint(GetResponse response) { - return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); - } - - public void batchRead(MultiGetRequest request, ActionListener listener) { - clientUtil.execute(MultiGetAction.INSTANCE, request, listener); - } - - public void batchWrite(BulkRequest request, ActionListener listener) { - if (indexUtil.doesCheckpointIndexExist()) { - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - // create index failure. Notify callers using listener. 
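
For reference, the benign race handled just below, where two nodes try to create the checkpoint index at once, as a minimal predicate (sketch only; hypothetical names):

import org.opensearch.ExceptionsHelper;
import org.opensearch.ResourceAlreadyExistsException;

// Sketch: if the create-index call fails only because another node created
// the index first, the index exists either way and the pending write can proceed.
final class IndexCreationRaceSketch {
    static boolean lostBenignCreationRace(Exception exception) {
        return ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException;
    }
}
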
- listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged.")); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); - listener.onFailure(exception); - } - })); - } - } - - private Optional convertToTRCF(Optional rcf, Optional kllThreshold) { - if (!rcf.isPresent()) { + private Optional convertToTRCF(RandomCutForest rcf, Optional kllThreshold) { + if (rcf == null) { return Optional.empty(); } // if there is no threshold model (e.g., threshold model is deleted by HourlyCron), we are gonna @@ -744,20 +565,16 @@ private Optional convertToTRCF(Optional { + private static final Logger logger = LogManager.getLogger(ADEntityColdStart.class); + + /** + * Constructor + * + * @param clock UTC clock + * @param threadPool Accessor to different threadpools + * @param nodeStateManager Storing node state + * @param rcfSampleSize The sample size used by stream samplers in this forest + * @param numberOfTrees The number of trees in this forest. + * @param rcfTimeDecay rcf samples time decay constant + * @param numMinSamples The number of points required by stream samplers before + * results are returned. + * @param defaultSampleStride default sample distances measured in detector intervals. + * @param defaultTrainSamples Default train samples to collect. + * @param searchFeatureDao Used to issue OS queries. + * @param thresholdMinPvalue min P-value for thresholding + * @param featureManager Used to create features for models. + * @param modelTtl time-to-live before last access time of the cold start cache. + * We have a cache to record entities that have run cold starts to avoid + * repeated unsuccessful cold start. 
+     * @param checkpointWriteWorker queue to insert model checkpoints
+     * @param rcfSeed rcf random seed
+     * @param maxRoundofColdStart max number of rounds of cold start
+     * @param coolDownMinutes cool down minutes when OpenSearch is overloaded
+     */
+    public ADEntityColdStart(
+        Clock clock,
+        ThreadPool threadPool,
+        NodeStateManager nodeStateManager,
+        int rcfSampleSize,
+        int numberOfTrees,
+        double rcfTimeDecay,
+        int numMinSamples,
+        int defaultSampleStride,
+        int defaultTrainSamples,
+        SearchFeatureDao searchFeatureDao,
+        double thresholdMinPvalue,
+        FeatureManager featureManager,
+        Duration modelTtl,
+        ADCheckpointWriteWorker checkpointWriteWorker,
+        long rcfSeed,
+        int maxRoundofColdStart,
+        int coolDownMinutes
+    ) {
+        super(
+            modelTtl,
+            coolDownMinutes,
+            clock,
+            threadPool,
+            numMinSamples,
+            checkpointWriteWorker,
+            rcfSeed,
+            numberOfTrees,
+            rcfSampleSize,
+            thresholdMinPvalue,
+            rcfTimeDecay,
+            nodeStateManager,
+            defaultSampleStride,
+            defaultTrainSamples,
+            searchFeatureDao,
+            featureManager,
+            maxRoundofColdStart,
+            TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME,
+            AnalysisType.AD
+        );
+    }
+
+    public ADEntityColdStart(
+        Clock clock,
+        ThreadPool threadPool,
+        NodeStateManager nodeStateManager,
+        int rcfSampleSize,
+        int numberOfTrees,
+        double rcfTimeDecay,
+        int numMinSamples,
+        int maxSampleStride,
+        int maxTrainSamples,
+        SearchFeatureDao searchFeatureDao,
+        double thresholdMinPvalue,
+        FeatureManager featureManager,
+        Duration modelTtl,
+        ADCheckpointWriteWorker checkpointWriteQueue,
+        int maxRoundofColdStart,
+        int coolDownMinutes
+    ) {
+        this(
+            clock,
+            threadPool,
+            nodeStateManager,
+            rcfSampleSize,
+            numberOfTrees,
+            rcfTimeDecay,
+            numMinSamples,
+            maxSampleStride,
+            maxTrainSamples,
+            searchFeatureDao,
+            thresholdMinPvalue,
+            featureManager,
+            modelTtl,
+            checkpointWriteQueue,
+            -1,
+            maxRoundofColdStart,
+            coolDownMinutes
+        );
+    }
+
+    /**
+     * Train model using given data points and save the trained model.
+     *
+     * @param pointSamplePair A pair consisting of a queue of continuous data points,
+     *  in ascending order of timestamps, and the last seen sample.
+     * @param entity Entity instance
+     * @param entityState Entity state associated with the model Id
+     * @param config detector configuration accessor
+     * @param taskId associated task id, if any
+     * @return the data points used for training, or null if training did not run
+     */
+    @Override
+    protected List> trainModelFromDataSegments(
+        Pair>, Sample> pointSamplePair,
+        Optional entity,
+        ModelState entityState,
+        Config config,
+        String taskId
+    ) {
+        if (entity.isEmpty()) {
+            throw new IllegalArgumentException("We offer only HC cold start");
+        }
+
+        List> dataPoints = pointSamplePair.getKey();
+        if (dataPoints == null || dataPoints.size() == 0) {
+            logger.info("Return early since data points are empty.");
+            return null;
+        }
+
+        double[] firstPoint = dataPoints.get(0).getValue();
+        if (firstPoint == null || firstPoint.length == 0) {
+            logger.info("Return early since the first data point is empty.");
+            return null;
+        }
+
+        int shingleSize = config.getShingleSize();
+        int dimensions = firstPoint.length * shingleSize;
+        ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest
+            .builder()
+            .dimensions(dimensions)
+            .sampleSize(rcfSampleSize)
+            .numberOfTrees(numberOfTrees)
+            .timeDecay(rcfTimeDecay)
+            .outputAfter(numMinSamples)
+            .initialAcceptFraction(initialAcceptFraction)
+            .parallelExecutionEnabled(false)
+            .compact(true)
+            .precision(Precision.FLOAT_32)
+            .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO)
+            // same with dimension for opportunistic memory saving
+            // Usually, we use it as shingleSize(dimension). 
When a new point comes in, we will + // look at the point store if there is any overlapping. Say the previously-stored + // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize + // overlapping x3, x4, and only store x5, x6. + .shingleSize(shingleSize) + .internalShinglingEnabled(true) + .anomalyRate(1 - this.thresholdMinPvalue) + .transformMethod(TransformMethod.NORMALIZE) + .alertOnce(true) + .autoAdjust(true); + + if (rcfSeed > 0) { + rcfBuilder.randomSeed(rcfSeed); + } + ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder); + + for (int i = 0; i < dataPoints.size(); i++) { + double[] dataValue = dataPoints.get(i).getValue(); + trcf.process(dataValue, 0); + } + + entityState.setModel(trcf); + + entityState.setLastUsedTime(clock.instant()); + entityState.setLastProcessedSample(pointSamplePair.getValue()); + + // save to checkpoint + checkpointWriteWorker.write(entityState, true, RequestPriority.MEDIUM); + + return dataPoints; + } + + @Override + protected boolean isInterpolationInColdStartEnabled() { + return ADEnabledSetting.isInterpolationInColdStartEnabled(); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelManager.java b/src/main/java/org/opensearch/ad/ml/ADModelManager.java similarity index 71% rename from src/main/java/org/opensearch/ad/ml/ModelManager.java rename to src/main/java/org/opensearch/ad/ml/ADModelManager.java index 14f935aae..0f4c00bda 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ADModelManager.java @@ -14,7 +14,6 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; @@ -23,7 +22,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -31,22 +29,28 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.DetectorModelSize; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.DateUtils; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.AnalysisModelSize; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; -import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DateUtils; import 
com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.config.Precision; @@ -57,51 +61,27 @@ /** * A facade managing ML operations and models. */ -public class ModelManager implements DetectorModelSize { +public class ADModelManager extends + ModelManager + implements + AnalysisModelSize { protected static final String ENTITY_SAMPLE = "sp"; protected static final String ENTITY_RCF = "rcf"; protected static final String ENTITY_THRESHOLD = "th"; - public enum ModelType { - RCF("rcf"), - THRESHOLD("threshold"), - ENTITY("entity"); - - private String name; - - ModelType(String name) { - this.name = name; - } - - public String getName() { - return name; - } - } - - private static final Logger logger = LogManager.getLogger(ModelManager.class); + private static final Logger logger = LogManager.getLogger(ADModelManager.class); // states - private TRCFMemoryAwareConcurrentHashmap forests; + private MemoryAwareConcurrentHashmap forests; private Map> thresholds; // configuration - private final int rcfNumTrees; - private final int rcfNumSamplesInTree; - private final double rcfTimeDecay; - private final int rcfNumMinSamples; + private final double thresholdMinPvalue; private final int minPreviewSize; private final Duration modelTtl; private Duration checkpointInterval; - // dependencies - private final CheckpointDao checkpointDao; - private final Clock clock; - public FeatureManager featureManager; - - private EntityColdStarter entityColdStarter; - private MemoryTracker memoryTracker; - private final double initialAcceptFraction; /** @@ -123,8 +103,8 @@ public String getName() { * @param settings Node settings * @param clusterService Cluster service accessor */ - public ModelManager( - CheckpointDao checkpointDao, + public ADModelManager( + ADCheckpointDao checkpointDao, Clock clock, int rcfNumTrees, int rcfNumSamplesInTree, @@ -134,18 +114,24 @@ public ModelManager( int minPreviewSize, Duration modelTtl, Setting checkpointIntervalSetting, - EntityColdStarter entityColdStarter, + ADEntityColdStart entityColdStarter, FeatureManager featureManager, MemoryTracker memoryTracker, Settings settings, ClusterService clusterService ) { - this.checkpointDao = checkpointDao; - this.clock = clock; - this.rcfNumTrees = rcfNumTrees; - this.rcfNumSamplesInTree = rcfNumSamplesInTree; - this.rcfTimeDecay = rcfTimeDecay; - this.rcfNumMinSamples = rcfNumMinSamples; + super( + rcfNumTrees, + rcfNumSamplesInTree, + rcfTimeDecay, + rcfNumMinSamples, + entityColdStarter, + memoryTracker, + clock, + featureManager, + checkpointDao + ); + this.thresholdMinPvalue = thresholdMinPvalue; this.minPreviewSize = minPreviewSize; this.modelTtl = modelTtl; @@ -156,12 +142,9 @@ public ModelManager( .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); } - this.forests = new TRCFMemoryAwareConcurrentHashmap<>(memoryTracker); + this.forests = new MemoryAwareConcurrentHashmap<>(memoryTracker); this.thresholds = new ConcurrentHashMap<>(); - this.entityColdStarter = entityColdStarter; - this.featureManager = featureManager; - this.memoryTracker = memoryTracker; this.initialAcceptFraction = rcfNumMinSamples * 1.0d / rcfNumSamplesInTree; } @@ -198,10 +181,14 @@ private void getTRcfResult( ) { modelState.setLastUsedTime(clock.instant()); - ThresholdedRandomCutForest trcf = modelState.getModel(); + Optional trcfOptional = modelState.getModel(); + if (trcfOptional.isEmpty()) { + listener.onFailure(new TimeSeriesException("empty model")); + return; + } try { - 
AnomalyDescriptor result = trcf.process(point, 0); - double[] attribution = normalizeAttribution(trcf.getForest(), result.getRelevantAttribution()); + AnomalyDescriptor result = trcfOptional.get().process(point, 0); + double[] attribution = normalizeAttribution(trcfOptional.get().getForest(), result.getRelevantAttribution()); listener .onResponse( new ThresholdingResult( @@ -276,7 +263,7 @@ private double[] createEmptyAttribution(RandomCutForest forest) { return new double[baseDimensions]; } - private Optional> restoreModelState( + Optional> restoreModelState( Optional rcfModel, String modelId, String detectorId @@ -286,7 +273,16 @@ private Optional> restoreModelState( } return rcfModel .filter(rcf -> memoryTracker.isHostingAllowed(detectorId, rcf)) - .map(rcf -> ModelState.createSingleEntityModelState(rcf, modelId, detectorId, ModelType.RCF.getName(), clock)); + .map( + rcf -> new ModelState( + rcf, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + null + ) + ); } private void processRestoredTRcf( @@ -320,13 +316,16 @@ private void processRestoredCheckpoint( ) { logger.info("Restoring checkpoint for {}", modelId); Optional> model = restoreModelState(checkpointModel, modelId, detectorId); - if (model.isPresent()) { - forests.put(modelId, model.get()); - if (model.get().getModel() != null && model.get().getModel().getForest() != null) - listener.onResponse(model.get().getModel().getForest().getTotalUpdates()); - } else { - listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); - } + model.ifPresentOrElse(modelState -> { + forests.put(modelId, modelState); + modelState.getModel().ifPresent(trcf -> { + if (trcf.getForest() != null) { + listener.onResponse(trcf.getForest().getTotalUpdates()); + } else { + listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); + } + }); + }, () -> listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId))); } /** @@ -356,14 +355,26 @@ private void getThresholdingResult( double score, ActionListener listener ) { - ThresholdingModel threshold = modelState.getModel(); - double grade = threshold.grade(score); - double confidence = threshold.confidence(); - if (score > 0) { - threshold.update(score); + Optional thresholdOptional = modelState.getModel(); + if (thresholdOptional.isPresent()) { + ThresholdingModel threshold = thresholdOptional.get(); + double grade = threshold.grade(score); + double confidence = threshold.confidence(); + if (score > 0) { + threshold.update(score); + } + modelState.setLastUsedTime(clock.instant()); + listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } else { + listener + .onFailure( + new ResourceNotFoundException( + modelState.getConfigId(), + ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelState.getModelId() + ) + ); } - modelState.setLastUsedTime(clock.instant()); - listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } private void processThresholdCheckpoint( @@ -374,9 +385,7 @@ private void processThresholdCheckpoint( ActionListener listener ) { Optional> model = thresholdModel - .map( - threshold -> ModelState.createSingleEntityModelState(threshold, modelId, detectorId, ModelType.THRESHOLD.getName(), clock) - ); + .map(threshold -> new ModelState<>(threshold, modelId, detectorId, ModelManager.ModelType.THRESHOLD.getName(), clock, null)); if (model.isPresent()) { thresholds.put(modelId, 
model.get()); getThresholdingResult(model.get(), score, listener); @@ -424,8 +433,8 @@ private void stopModel(Map> models, String modelId, Ac Optional> modelState = Optional .ofNullable(models.remove(modelId)) .filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)); - if (modelState.isPresent()) { - T model = modelState.get().getModel(); + if (modelState.isPresent() && modelState.get().getModel().isPresent()) { + T model = modelState.get().getModel().get(); if (model instanceof ThresholdedRandomCutForest) { checkpointDao .putTRCFCheckpoint( @@ -460,29 +469,6 @@ public void clear(String detectorId, ActionListener listener) { clearModels(detectorId, forests, ActionListener.wrap(r -> clearModels(detectorId, thresholds, listener), listener::onFailure)); } - private void clearModels(String detectorId, Map models, ActionListener listener) { - Iterator id = models.keySet().iterator(); - clearModelForIterator(detectorId, models, id, listener); - } - - private void clearModelForIterator(String detectorId, Map models, Iterator idIter, ActionListener listener) { - if (idIter.hasNext()) { - String modelId = idIter.next(); - if (SingleStreamModelIdMapper.getDetectorIdForModelId(modelId).equals(detectorId)) { - models.remove(modelId); - checkpointDao - .deleteModelCheckpoint( - modelId, - ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) - ); - } else { - clearModelForIterator(detectorId, models, idIter, listener); - } - } else { - listener.onResponse(null); - } - } - /** * Trains and saves cold-start AD models. * @@ -579,13 +565,18 @@ private void maintenanceForIterator( logger.warn("Failed to finish maintenance for model id " + modelId, e); maintenanceForIterator(models, iter, listener); }); - T model = modelState.getModel(); - if (model instanceof ThresholdedRandomCutForest) { - checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); - } else if (model instanceof ThresholdingModel) { - checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + Optional modelOptional = modelState.getModel(); + if (modelOptional.isPresent()) { + T model = modelOptional.get(); + if (model instanceof ThresholdedRandomCutForest) { + checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); + } else if (model instanceof ThresholdingModel) { + checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + } else { + checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + } } else { - checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + maintenanceForIterator(models, iter, listener); } } else { maintenanceForIterator(models, iter, listener); @@ -657,17 +648,11 @@ public List getPreviewResults(double[][] dataPoints, int shi @Override public Map getModelSize(String detectorId) { Map res = new HashMap<>(); - forests - .entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) - .forEach(entry -> { - res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(entry.getValue().getModel())); - }); + res.putAll(forests.getModelSize(detectorId)); thresholds .entrySet() .stream() - .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .filter(entry -> 
SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId)) .forEach(entry -> { res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes()); }); @@ -683,8 +668,8 @@ public Map getModelSize(String detectorId) { public void getTotalUpdates(String modelId, String detectorId, ActionListener listener) { ModelState model = forests.get(modelId); if (model != null) { - if (model.getModel() != null && model.getModel().getForest() != null) { - listener.onResponse(model.getModel().getForest().getTotalUpdates()); + if (model.getModel().isPresent() && model.getModel().get().getForest() != null) { + listener.onResponse(model.getModel().get().getForest().getTotalUpdates()); } else { listener.onResponse(0L); } @@ -698,131 +683,13 @@ public void getTotalUpdates(String modelId, String detectorId, ActionListener modelState, - String modelId, - Entity entity, - int shingleSize - ) { - ThresholdingResult result = new ThresholdingResult(0, 0, 0); - if (modelState != null) { - EntityModel entityModel = modelState.getModel(); - - if (entityModel == null) { - entityModel = new EntityModel(entity, new ArrayDeque<>(), null); - modelState.setModel(entityModel); - } - - if (!entityModel.getTrcf().isPresent()) { - entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); - } - - if (entityModel.getTrcf().isPresent()) { - result = score(datapoint, modelId, modelState); - } else { - entityModel.addSample(datapoint); - } - } - return result; - } - - public ThresholdingResult score(double[] feature, String modelId, ModelState modelState) { - ThresholdingResult result = new ThresholdingResult(0, 0, 0); - EntityModel model = modelState.getModel(); - try { - if (model != null && model.getTrcf().isPresent()) { - ThresholdedRandomCutForest trcf = model.getTrcf().get(); - Optional.ofNullable(model.getSamples()).ifPresent(q -> { - q.stream().forEach(s -> trcf.process(s, 0)); - q.clear(); - }); - result = toResult(trcf.getForest(), trcf.process(feature, 0)); - } - } catch (Exception e) { - logger - .error( - new ParameterizedMessage( - "Fail to score for [{}]: model Id [{}], feature [{}]", - modelState.getModel().getEntity(), - modelId, - Arrays.toString(feature) - ), - e - ); - throw e; - } finally { - modelState.setLastUsedTime(clock.instant()); - } - return result; - } - - /** - * Instantiate an entity state out of checkpoint. Train models if there are - * enough samples. 
- * @param checkpoint Checkpoint loaded from index - * @param entity objects to access Entity attributes - * @param modelId Model Id - * @param detectorId Detector Id - * @param shingleSize Shingle size - * - * @return updated model state - * - */ - public ModelState processEntityCheckpoint( - Optional> checkpoint, - Entity entity, - String modelId, - String detectorId, - int shingleSize - ) { - // entity state to instantiate - ModelState modelState = new ModelState<>( - new EntityModel(entity, new ArrayDeque<>(), null), - modelId, - detectorId, - ModelType.ENTITY.getName(), - clock, - 0 - ); - - if (checkpoint.isPresent()) { - Entry modelToTime = checkpoint.get(); - EntityModel restoredModel = modelToTime.getKey(); - combineSamples(modelState.getModel(), restoredModel); - modelState.setModel(restoredModel); - modelState.setLastCheckpointTime(modelToTime.getValue()); - } - EntityModel model = modelState.getModel(); - if (model == null) { - model = new EntityModel(null, new ArrayDeque<>(), null); - modelState.setModel(model); - } - - if (!model.getTrcf().isPresent() && model.getSamples() != null && model.getSamples().size() >= rcfNumMinSamples) { - entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); - } - return modelState; - } - - private void combineSamples(EntityModel fromModel, EntityModel toModel) { - Queue samples = fromModel.getSamples(); - while (samples.peek() != null) { - toModel.addSample(samples.poll()); - } + @Override + protected ThresholdingResult createEmptyResult() { + return new ThresholdingResult(0, 0, 0); } - private ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { + @Override + protected ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { return new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), anomalyDescriptor.getDataConfidence(), diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java deleted file mode 100644 index 1044b84ce..000000000 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ /dev/null @@ -1,758 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -package org.opensearch.ad.ml; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.AbstractMap.SimpleImmutableEntry; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.Queue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.core.util.Throwables; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.caching.DoorKeeper; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.CleanState; -import org.opensearch.timeseries.MaintenanceState; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.dataprocessor.Imputer; -import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.settings.TimeSeriesSettings; -import org.opensearch.timeseries.util.ExceptionUtil; - -import com.amazon.randomcutforest.config.Precision; -import com.amazon.randomcutforest.config.TransformMethod; -import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; - -/** - * Training models for HCAD detectors - * - */ -public class EntityColdStarter implements MaintenanceState, CleanState { - private static final Logger logger = LogManager.getLogger(EntityColdStarter.class); - private final Clock clock; - private final ThreadPool threadPool; - private final NodeStateManager nodeStateManager; - private final int rcfSampleSize; - private final int numberOfTrees; - private final double rcfTimeDecay; - private final int numMinSamples; - private final double thresholdMinPvalue; - private final int defaulStrideLength; - private final int defaultNumberOfSamples; - private final Imputer imputer; - private final SearchFeatureDao searchFeatureDao; - private Instant lastThrottledColdStartTime; - private final FeatureManager featureManager; - private int coolDownMinutes; - // A bloom filter checked before cold start to ensure we don't repeatedly - // retry cold start of the same model. - // keys are detector ids. - private Map doorKeepers; - private final Duration modelTtl; - private final CheckpointWriteWorker checkpointWriteQueue; - // make sure rcf use a specific random seed. 
Otherwise, we will use a random random (not a typo) seed. - // this is mainly used for testing to make sure the model we trained and the reference rcf produce - // the same results - private final long rcfSeed; - private final int maxRoundofColdStart; - private final double initialAcceptFraction; - - /** - * Constructor - * - * @param clock UTC clock - * @param threadPool Accessor to different threadpools - * @param nodeStateManager Storing node state - * @param rcfSampleSize The sample size used by stream samplers in this forest - * @param numberOfTrees The number of trees in this forest. - * @param rcfTimeDecay rcf samples time decay constant - * @param numMinSamples The number of points required by stream samplers before - * results are returned. - * @param defaultSampleStride default sample distances measured in detector intervals. - * @param defaultTrainSamples Default train samples to collect. - * @param imputer Used to generate data points between samples. - * @param searchFeatureDao Used to issue ES queries. - * @param thresholdMinPvalue min P-value for thresholding - * @param featureManager Used to create features for models. - * @param settings ES settings accessor - * @param modelTtl time-to-live before last access time of the cold start cache. - * We have a cache to record entities that have run cold starts to avoid - * repeated unsuccessful cold start. - * @param checkpointWriteQueue queue to insert model checkpoints - * @param rcfSeed rcf random seed - * @param maxRoundofColdStart max number of rounds of cold start - */ - public EntityColdStarter( - Clock clock, - ThreadPool threadPool, - NodeStateManager nodeStateManager, - int rcfSampleSize, - int numberOfTrees, - double rcfTimeDecay, - int numMinSamples, - int defaultSampleStride, - int defaultTrainSamples, - Imputer imputer, - SearchFeatureDao searchFeatureDao, - double thresholdMinPvalue, - FeatureManager featureManager, - Settings settings, - Duration modelTtl, - CheckpointWriteWorker checkpointWriteQueue, - long rcfSeed, - int maxRoundofColdStart - ) { - this.clock = clock; - this.lastThrottledColdStartTime = Instant.MIN; - this.threadPool = threadPool; - this.nodeStateManager = nodeStateManager; - this.rcfSampleSize = rcfSampleSize; - this.numberOfTrees = numberOfTrees; - this.rcfTimeDecay = rcfTimeDecay; - this.numMinSamples = numMinSamples; - this.defaulStrideLength = defaultSampleStride; - this.defaultNumberOfSamples = defaultTrainSamples; - this.imputer = imputer; - this.searchFeatureDao = searchFeatureDao; - this.thresholdMinPvalue = thresholdMinPvalue; - this.featureManager = featureManager; - this.coolDownMinutes = (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()); - this.doorKeepers = new ConcurrentHashMap<>(); - this.modelTtl = modelTtl; - this.checkpointWriteQueue = checkpointWriteQueue; - this.rcfSeed = rcfSeed; - this.maxRoundofColdStart = maxRoundofColdStart; - this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; - } - - public EntityColdStarter( - Clock clock, - ThreadPool threadPool, - NodeStateManager nodeStateManager, - int rcfSampleSize, - int numberOfTrees, - double rcfTimeDecay, - int numMinSamples, - int maxSampleStride, - int maxTrainSamples, - Imputer imputer, - SearchFeatureDao searchFeatureDao, - double thresholdMinPvalue, - FeatureManager featureManager, - Settings settings, - Duration modelTtl, - CheckpointWriteWorker checkpointWriteQueue, - int maxRoundofColdStart - ) { - this( - clock, - threadPool, - nodeStateManager, - rcfSampleSize, - numberOfTrees, - rcfTimeDecay, - 
numMinSamples, - maxSampleStride, - maxTrainSamples, - imputer, - searchFeatureDao, - thresholdMinPvalue, - featureManager, - settings, - modelTtl, - checkpointWriteQueue, - -1, - maxRoundofColdStart - ); - } - - /** - * Training model for an entity - * @param modelId model Id corresponding to the entity - * @param entity the entity's information - * @param detectorId the detector Id corresponding to the entity - * @param modelState model state associated with the entity - * @param listener call back to call after cold start - */ - private void coldStart( - String modelId, - Entity entity, - String detectorId, - ModelState modelState, - AnomalyDetector detector, - ActionListener listener - ) { - logger.debug("Trigger cold start for {}", modelId); - - if (modelState == null || entity == null) { - listener - .onFailure( - new IllegalArgumentException( - String - .format( - Locale.ROOT, - "Cannot have empty model state or entity: model state [%b], entity [%b]", - modelState == null, - entity == null - ) - ) - ); - return; - } - - if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { - listener.onResponse(null); - return; - } - - boolean earlyExit = true; - try { - DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(detectorId, id -> { - // reset every 60 intervals - return new DoorKeeper( - TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, - TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, - detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), - clock - ); - }); - - // Won't retry cold start within 60 intervals for an entity - if (doorKeeper.mightContain(modelId)) { - return; - } - - doorKeeper.put(modelId); - - ActionListener>> coldStartCallBack = ActionListener.wrap(trainingData -> { - try { - if (trainingData.isPresent()) { - List dataPoints = trainingData.get(); - extractTrainSamples(dataPoints, modelId, modelState); - Queue samples = modelState.getModel().getSamples(); - // only train models if we have enough samples - if (samples.size() >= numMinSamples) { - // The function trainModelFromDataSegments will save a trained a model. trainModelFromDataSegments is called by - // multiple places so I want to make the saving model implicit just in case I forgot. 
- trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize()); - logger.info("Succeeded in training entity: {}", modelId); - } else { - // save to checkpoint - checkpointWriteQueue.write(modelState, true, RequestPriority.MEDIUM); - logger.info("Not enough data to train entity: {}, currently we have {}", modelId, samples.size()); - } - } else { - logger.info("Cannot get training data for {}", modelId); - } - listener.onResponse(null); - } catch (Exception e) { - listener.onFailure(e); - } - }, exception -> { - try { - logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); - Throwable cause = Throwables.getRootCause(exception); - if (ExceptionUtil.isOverloaded(cause)) { - logger.error("too many requests"); - lastThrottledColdStartTime = Instant.now(); - } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { - // e.g., cannot find anomaly detector - nodeStateManager.setException(detectorId, exception); - } else { - nodeStateManager.setException(detectorId, new TimeSeriesException(detectorId, cause)); - } - listener.onFailure(exception); - } catch (Exception e) { - listener.onFailure(e); - } - }); - - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> getEntityColdStartData( - detectorId, - entity, - new ThreadedActionListener<>( - logger, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - coldStartCallBack, - false - ) - ) - ); - earlyExit = false; - } finally { - if (earlyExit) { - listener.onResponse(null); - } - } - } - - /** - * Train model using given data points and save the trained model. - * - * @param dataPoints Queue of continuous data points, in ascending order of timestamps - * @param entity Entity instance - * @param entityState Entity state associated with the model Id - */ - private void trainModelFromDataSegments( - Queue dataPoints, - Entity entity, - ModelState entityState, - int shingleSize - ) { - if (dataPoints == null || dataPoints.size() == 0) { - throw new IllegalArgumentException("Data points must not be empty."); - } - - double[] firstPoint = dataPoints.peek(); - if (firstPoint == null || firstPoint.length == 0) { - throw new IllegalArgumentException("Data points must not be empty."); - } - int dimensions = firstPoint.length * shingleSize; - ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest - .builder() - .dimensions(dimensions) - .sampleSize(rcfSampleSize) - .numberOfTrees(numberOfTrees) - .timeDecay(rcfTimeDecay) - .outputAfter(numMinSamples) - .initialAcceptFraction(initialAcceptFraction) - .parallelExecutionEnabled(false) - .compact(true) - .precision(Precision.FLOAT_32) - .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) - // same with dimension for opportunistic memory saving - // Usually, we use it as shingleSize(dimension). When a new point comes in, we will - // look at the point store if there is any overlapping. Say the previously-stored - // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize - // overlapping x3, x4, and only store x5, x6. 
- .shingleSize(shingleSize) - .internalShinglingEnabled(true) - .anomalyRate(1 - this.thresholdMinPvalue) - .transformMethod(TransformMethod.NORMALIZE) - .alertOnce(true) - .autoAdjust(true); - - if (rcfSeed > 0) { - rcfBuilder.randomSeed(rcfSeed); - } - ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder); - while (!dataPoints.isEmpty()) { - trcf.process(dataPoints.poll(), 0); - } - EntityModel model = entityState.getModel(); - if (model == null) { - model = new EntityModel(entity, new ArrayDeque<>(), null); - } - model.setTrcf(trcf); - - entityState.setLastUsedTime(clock.instant()); - - // save to checkpoint - checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM); - } - - /** - * Get training data for an entity. - * - * We first note the maximum and minimum timestamp, and sample at most 24 points - * (with 60 points apart between two neighboring samples) between those minimum - * and maximum timestamps. Samples can be missing. We only interpolate points - * between present neighboring samples. We then transform samples and interpolate - * points to shingles. Finally, full shingles will be used for cold start. - * - * @param detectorId detector Id - * @param entity the entity's information - * @param listener listener to return training data - */ - private void getEntityColdStartData(String detectorId, Entity entity, ActionListener>> listener) { - ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { - if (!detectorOp.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); - return; - } - List coldStartData = new ArrayList<>(); - AnomalyDetector detector = (AnomalyDetector) detectorOp.get(); - - ActionListener> minTimeListener = ActionListener.wrap(earliest -> { - if (earliest.isPresent()) { - long startTimeMs = earliest.get().longValue(); - - // End time uses milliseconds as start time is assumed to be in milliseconds. - // Opensearch uses a set of preconfigured formats to recognize and parse these - // strings into a long value - // representing milliseconds-since-the-epoch in UTC. 
- // More on https://tinyurl.com/wub4fk92 - - long endTimeMs = clock.millis(); - Pair params = selectRangeParam(detector); - int stride = params.getLeft(); - int numberOfSamples = params.getRight(); - - // we start with round 0 - getFeatures(listener, 0, coldStartData, detector, entity, stride, numberOfSamples, startTimeMs, endTimeMs); - } else { - listener.onResponse(Optional.empty()); - } - }, listener::onFailure); - - searchFeatureDao - .getMinDataTime( - detector, - Optional.ofNullable(entity), - AnalysisType.AD, - new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, minTimeListener, false) - ); - - }, listener::onFailure); - - nodeStateManager - .getConfig( - detectorId, - AnalysisType.AD, - new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false) - ); - } - - private void getFeatures( - ActionListener>> listener, - int round, - List lastRoundColdStartData, - AnomalyDetector detector, - Entity entity, - int stride, - int numberOfSamples, - long startTimeMs, - long endTimeMs - ) { - if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < detector.getIntervalInMilliseconds()) { - listener.onResponse(Optional.of(lastRoundColdStartData)); - return; - } - - // create ranges in desending order, we will reorder it in ascending order - // in Opensearch's response - List> sampleRanges = getTrainSampleRanges(detector, startTimeMs, endTimeMs, stride, numberOfSamples); - - if (sampleRanges.isEmpty()) { - listener.onResponse(Optional.of(lastRoundColdStartData)); - return; - } - - ActionListener>> getFeaturelistener = ActionListener.wrap(featureSamples -> { - // storing lastSample = null; - List currentRoundColdStartData = new ArrayList<>(); - - // featuresSamples are in ascending order of time. - for (int i = 0; i < featureSamples.size(); i++) { - Optional featuresOptional = featureSamples.get(i); - if (featuresOptional.isPresent()) { - // we only need the most recent two samples - // For the missing samples we use linear interpolation as well. - // Denote the Samples S0, S1, ... as samples in reverse order of time. - // Each [Si​,Si−1​]corresponds to strideLength * detector interval. - // If we got samples for S0, S1, S4 (both S2 and S3 are missing), then - // we interpolate the [S4,S1] into 3*strideLength pieces. - if (lastSample != null) { - // right sample has index i and feature featuresOptional.get() - int numInterpolants = (i - lastSample.getLeft()) * stride + 1; - double[][] points = featureManager - .transpose( - imputer - .impute( - featureManager.transpose(new double[][] { lastSample.getRight(), featuresOptional.get() }), - numInterpolants - ) - ); - // the last point will be included in the next iteration or we process - // it in the end. We don't want to repeatedly include the samples twice. - currentRoundColdStartData.add(Arrays.copyOfRange(points, 0, points.length - 1)); - } - lastSample = Pair.of(i, featuresOptional.get()); - } - } - - if (lastSample != null) { - currentRoundColdStartData.add(new double[][] { lastSample.getRight() }); - } - if (lastRoundColdStartData.size() > 0) { - currentRoundColdStartData.addAll(lastRoundColdStartData); - } - - // If the first round of probe provides (32+shingleSize) points (note that if S0 is - // missing or all Si​ for some i > N is missing then we would miss a lot of points. - // Otherwise we can issue another round of query — if there is any sample in the - // second round then we would have 32 + shingleSize points. 
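
A worked instance of the interpolation arithmetic described in the comments above (sketch only; the constants are this example's assumptions, not configuration values). With samples S0, S1 and S4 present and a stride of 6, the gap [S1, S4] is rebuilt from 3 * strideLength pieces:

// Worked example: interpolant count for the S0, S1, S4 case described above.
final class InterpolantCountSketch {
    public static void main(String[] args) {
        int stride = 6;      // detector intervals between two neighboring samples
        int current = 4;     // index of the sample just received (S4)
        int lastSeen = 1;    // index of the last sample present (S1)
        int numInterpolants = (current - lastSeen) * stride + 1;
        // 19: the 18 (= 3 * strideLength) pieces have 19 boundary points in total
        System.out.println(numInterpolants);
    }
}
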
-    private int calculateColdStartDataSize(List<double[][]> coldStartData) {
-        int size = 0;
-        for (int i = 0; i < coldStartData.size(); i++) {
-            size += coldStartData.get(i).length;
-        }
-        return size;
-    }
-
-    /**
-     * Select strideLength and numberOfSamples, where strideLength is the number of intervals
-     * between two samples and numberOfSamples is the number of training samples to fetch.
-     * If interpolation is disabled, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples.
-     *
-     * Algorithm:
-     *
-     * delta is the length of the detector interval in minutes.
-     *
-     * 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil((shingleSize + 32) / 24) * 24
-     *    and strideLength = 60 / delta. Note that if there is enough data, we may end up with a lot more
-     *    than shingleSize + 32 points, which is only good. This step tries to match the data with an
-     *    hourly pattern.
-     * 2. Otherwise, set numberOfSamples = shingleSize + 32 and strideLength = 1.
-     *    This should be an uncommon case, as we assume most users think in multiples of 5 minutes
-     *    (say 10 or 30 minutes). But if someone wants a 23-minute interval, and the system permits it,
-     *    we give it to them. In this case we disable interpolation, since we only want to interpolate
-     *    based on the hourly pattern; that's why we use 60 as the dividend in case 1, and the 23-minute
-     *    case does not fit that pattern. Note that the smallest delta that does not divide 60 is 7,
-     *    which is already quite a long interval to wait for one data point.
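A worked example of the two branches, assuming numMinSamples = 32 as in the javadoc's formulas (the sketch below just mirrors the branch logic; it is not the plugin's code):

import java.util.Arrays;

public class StrideSelectionSketch {
    /** Mirrors the selectRangeParam branch logic for a detector interval of delta minutes. */
    static int[] select(long delta, int shingleSize, int numMinSamples) {
        if (delta <= 30 && 60 % delta == 0) {
            int strideLength = (int) (60 / delta);
            int numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24;
            return new int[] { strideLength, numberOfSamples };
        }
        return new int[] { 1, shingleSize + numMinSamples };
    }

    public static void main(String[] args) {
        // 10-minute interval, shingle size 8: stride 6 (hourly), ceil(40 / 24) * 24 = 48 samples
        System.out.println(Arrays.toString(select(10, 8, 32))); // [6, 48]
        // 23-minute interval does not divide 60: stride 1, 8 + 32 = 40 samples
        System.out.println(Arrays.toString(select(23, 8, 32))); // [1, 40]
    }
}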
-     * @return the chosen strideLength and numberOfSamples
-     */
-    private Pair<Integer, Integer> selectRangeParam(AnomalyDetector detector) {
-        int shingleSize = detector.getShingleSize();
-        if (ADEnabledSetting.isInterpolationInColdStartEnabled()) {
-            long delta = detector.getIntervalInMinutes();
-
-            int strideLength = defaulStrideLength;
-            int numberOfSamples = defaultNumberOfSamples;
-            if (delta <= 30 && 60 % delta == 0) {
-                strideLength = (int) (60 / delta);
-                numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24;
-            } else {
-                strideLength = 1;
-                numberOfSamples = shingleSize + numMinSamples;
-            }
-            return Pair.of(strideLength, numberOfSamples);
-        } else {
-            return Pair.of(1, shingleSize + numMinSamples);
-        }
-    }
-
-    /**
-     * Get train samples within a time range.
-     *
-     * @param detector accessor to detector config
-     * @param startMilli range start
-     * @param endMilli range end
-     * @param stride the number of intervals between two samples
-     * @param numberOfSamples maximum training samples to fetch
-     * @return list of sample time ranges
-     */
-    private List<Entry<Long, Long>> getTrainSampleRanges(
-        AnomalyDetector detector,
-        long startMilli,
-        long endMilli,
-        int stride,
-        int numberOfSamples
-    ) {
-        long bucketSize = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMillis();
-        int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize);
-        // adjust if numStrides is more than the max samples
-        int numStrides = Math.min((int) Math.floor(numBuckets / (double) stride), numberOfSamples);
-        List<Entry<Long, Long>> sampleRanges = Stream
-            .iterate(endMilli, i -> i - stride * bucketSize)
-            .limit(numStrides)
-            .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time))
-            .collect(Collectors.toList());
-        return sampleRanges;
-    }
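Plugging concrete numbers into this method shows the descending layout that the comment in getFeatures refers to. A runnable sketch under assumed values (10-minute buckets, stride 6, three strides):

import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.List;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SampleRangeSketch {
    public static void main(String[] args) {
        long endMilli = 36_000_000L;
        long bucketSize = 600_000L; // 10-minute detector interval
        int stride = 6;             // one sample per hour
        int numStrides = 3;

        List<Entry<Long, Long>> ranges = Stream
            .iterate(endMilli, i -> i - stride * bucketSize)
            .limit(numStrides)
            .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time))
            .collect(Collectors.toList());

        // prints [35400000=36000000, 31800000=32400000, 28200000=28800000]: newest range first
        System.out.println(ranges);
    }
}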
-
-    /**
-     * Train models for the given entity.
-     * @param entity the entity's information
-     * @param detectorId detector Id
-     * @param modelState model state associated with the entity
-     * @param listener callback invoked before the method returns whenever EntityColdStarter
-     *  finishes training or encounters exceptions. The listener helps notify the
-     *  cold start queue to pull another request (if any) to execute.
-     */
-    public void trainModel(Entity entity, String detectorId, ModelState<EntityModel> modelState, ActionListener<Void> listener) {
-        nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> {
-            if (false == detectorOptional.isPresent()) {
-                logger.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId));
-                listener.onFailure(new TimeSeriesException(detectorId, "fail to find detector"));
-                return;
-            }
-
-            AnomalyDetector detector = (AnomalyDetector) detectorOptional.get();
-
-            Queue<double[]> samples = modelState.getModel().getSamples();
-            String modelId = modelState.getModelId();
-
-            if (samples.size() < this.numMinSamples) {
-                // we cannot get the last RCF score since cold start happens asynchronously
-                coldStart(modelId, entity, detectorId, modelState, detector, listener);
-            } else {
-                try {
-                    trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize());
-                    listener.onResponse(null);
-                } catch (Exception e) {
-                    listener.onFailure(e);
-                }
-            }
-
-        }, listener::onFailure));
-    }
-
-    public void trainModelFromExistingSamples(ModelState<EntityModel> modelState, int shingleSize) {
-        if (modelState == null || modelState.getModel() == null || modelState.getModel().getSamples() == null) {
-            return;
-        }
-
-        EntityModel model = modelState.getModel();
-        Queue<double[]> samples = model.getSamples();
-        if (samples.size() >= this.numMinSamples) {
-            try {
-                trainModelFromDataSegments(samples, model.getEntity().orElse(null), modelState, shingleSize);
-            } catch (Exception e) {
-                // e.g., an exception from RCF. We can do nothing except log the error.
-                // We won't retry training for the same entity in the cooldown period
-                // (60 detector intervals).
-                logger.error("Unexpected training error", e);
-            }
-        }
-    }
-
-    /**
-     * Extract training data and put them into ModelState
-     *
-     * @param coldstartDatapoints training data generated from cold start
-     * @param modelId model Id
-     * @param modelState entity state
-     */
-    private void extractTrainSamples(List<double[][]> coldstartDatapoints, String modelId, ModelState<EntityModel> modelState) {
-        if (coldstartDatapoints == null || coldstartDatapoints.size() == 0 || modelState == null) {
-            return;
-        }
-
-        EntityModel model = modelState.getModel();
-        if (model == null) {
-            model = new EntityModel(null, new ArrayDeque<>(), null);
-            modelState.setModel(model);
-        }
-
-        Queue<double[]> newSamples = new ArrayDeque<>();
-        for (double[][] consecutivePoints : coldstartDatapoints) {
-            for (int i = 0; i < consecutivePoints.length; i++) {
-                newSamples.add(consecutivePoints[i]);
-            }
-        }
-
-        model.setSamples(newSamples);
-    }
-
-    @Override
-    public void maintenance() {
-        doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> {
-            String detectorId = doorKeeperEntry.getKey();
-            DoorKeeper doorKeeper = doorKeeperEntry.getValue();
-            if (doorKeeper.expired(modelTtl)) {
-                doorKeepers.remove(detectorId);
-            } else {
-                doorKeeper.maintenance();
-            }
-        });
-    }
-
-    @Override
-    public void clear(String detectorId) {
-        doorKeepers.remove(detectorId);
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/ml/EntityModel.java b/src/main/java/org/opensearch/ad/ml/EntityModel.java
deleted file mode 100644
index 348ad8c6e..000000000
--- a/src/main/java/org/opensearch/ad/ml/EntityModel.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.ml;
-
-import java.util.ArrayDeque;
-import java.util.Optional;
-import java.util.Queue;
-
-import org.opensearch.timeseries.model.Entity;
-
-import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
-
-public class EntityModel {
-    private Entity entity;
-    // TODO: sample should record timestamp
-    private Queue<double[]> samples;
-
-    private ThresholdedRandomCutForest trcf;
-
-    /**
-     * Constructor with TRCF.
-     *
-     * @param entity entity if any
-     * @param samples samples with the model
-     * @param trcf thresholded rcf model
-     */
-    public EntityModel(Entity entity, Queue<double[]> samples, ThresholdedRandomCutForest trcf) {
-        this.entity = entity;
-        this.samples = samples;
-        this.trcf = trcf;
-    }
-
-    /**
-     * In the old checkpoint mapping, we don't have the entity. It's fine to miss the
-     * entity, as it is mostly used for debugging.
-     * @return entity
-     */
-    public Optional<Entity> getEntity() {
-        return Optional.ofNullable(entity);
-    }
-
-    public Queue<double[]> getSamples() {
-        return this.samples;
-    }
-
-    public void setSamples(Queue<double[]> samples) {
-        this.samples = samples;
-    }
-
-    public void addSample(double[] sample) {
-        if (this.samples == null) {
-            this.samples = new ArrayDeque<>();
-        }
-        if (sample != null && sample.length != 0) {
-            this.samples.add(sample);
-        }
-    }
-
-    /**
-     * Sets a TRCF model.
-     *
-     * @param trcf a TRCF model
-     */
-    public void setTrcf(ThresholdedRandomCutForest trcf) {
-        this.trcf = trcf;
-    }
-
-    /**
-     * Returns the optional TRCF model.
-     *
-     * @return the TRCF model or empty
-     */
-    public Optional<ThresholdedRandomCutForest> getTrcf() {
-        return Optional.ofNullable(this.trcf);
-    }
-
-    public void clear() {
-        if (samples != null) {
-            samples.clear();
-        }
-        trcf = null;
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java
deleted file mode 100644
index 2380173b0..000000000
--- a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.ml;
-
-import java.util.concurrent.ConcurrentHashMap;
-
-import org.opensearch.timeseries.MemoryTracker;
-import org.opensearch.timeseries.MemoryTracker.Origin;
-
-import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
-
-/**
- * A customized ConcurrentHashMap that can automatically consume and release memory.
- * This enables minimum change to our single-entity code as we just have to replace
- * the map implementation.
- *
- * Note: this is mainly used for single-entity detectors.
- */
-public class TRCFMemoryAwareConcurrentHashmap<K> extends ConcurrentHashMap<K, ModelState<EntityModel>> {
-    private final MemoryTracker memoryTracker;
-
-    public TRCFMemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) {
-        this.memoryTracker = memoryTracker;
-    }
-
-    @Override
-    public ModelState<EntityModel> remove(Object key) {
-        ModelState<EntityModel> deletedModelState = super.remove(key);
-        if (deletedModelState != null && deletedModelState.getModel() != null) {
-            long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel());
-            memoryTracker.releaseMemory(memoryToRelease, true, Origin.REAL_TIME_DETECTOR);
-        }
-        return deletedModelState;
-    }
-
-    @Override
-    public ModelState<EntityModel> put(K key, ModelState<EntityModel> value) {
-        ModelState<EntityModel> previousAssociatedState = super.put(key, value);
-        if (value != null && value.getModel() != null) {
-            long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel());
-            memoryTracker.consumeMemory(memoryToConsume, true, Origin.REAL_TIME_DETECTOR);
-        }
-        return previousAssociatedState;
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/model/ADTask.java b/src/main/java/org/opensearch/ad/model/ADTask.java
index 93566a0f0..18a905e2f 100644
--- a/src/main/java/org/opensearch/ad/model/ADTask.java
+++ b/src/main/java/org/opensearch/ad/model/ADTask.java
@@ -141,7 +141,7 @@ public static Builder builder() {
     }
 
     @Override
-    public boolean isEntityTask() {
+    public boolean isHistoricalEntityTask() {
         return ADTaskType.HISTORICAL_HC_ENTITY.name().equals(taskType);
     }
 
@@ -337,7 +337,8 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept
             detector.getCategoryFields(),
             detector.getUser(),
             detector.getCustomResultIndex(),
-            detector.getImputationOption()
+            detector.getImputationOption(),
+            detector.getTransformDecay()
         );
         return new Builder()
             .taskId(parsedTaskId)
diff --git a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java
index cd6eaeaa0..0a31d5d95 100644
--- a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java
+++ b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java
@@ -21,44 +21,30 @@
 import org.opensearch.Version;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
-import org.opensearch.core.common.io.stream.Writeable;
-import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.timeseries.TaskProfile;
 import org.opensearch.timeseries.annotation.Generated;
+import org.opensearch.timeseries.model.EntityTaskProfile;
 
 /**
  * One anomaly detection task means one detector starts to run until stopped.
*/ -public class ADTaskProfile implements ToXContentObject, Writeable { +public class ADTaskProfile extends TaskProfile { public static final String AD_TASK_FIELD = "ad_task"; - public static final String SHINGLE_SIZE_FIELD = "shingle_size"; - public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; public static final String THRESHOLD_MODEL_TRAINED_FIELD = "threshold_model_trained"; public static final String THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD = "threshold_model_training_data_size"; - public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; - public static final String NODE_ID_FIELD = "node_id"; - public static final String TASK_ID_FIELD = "task_id"; - public static final String AD_TASK_TYPE_FIELD = "task_type"; public static final String DETECTOR_TASK_SLOTS_FIELD = "detector_task_slots"; public static final String TOTAL_ENTITIES_INITED_FIELD = "total_entities_inited"; public static final String TOTAL_ENTITIES_COUNT_FIELD = "total_entities_count"; public static final String PENDING_ENTITIES_COUNT_FIELD = "pending_entities_count"; public static final String RUNNING_ENTITIES_COUNT_FIELD = "running_entities_count"; public static final String RUNNING_ENTITIES_FIELD = "running_entities"; - public static final String ENTITY_TASK_PROFILE_FIELD = "entity_task_profiles"; public static final String LATEST_HC_TASK_RUN_TIME_FIELD = "latest_hc_task_run_time"; - private ADTask adTask; - private Integer shingleSize; - private Long rcfTotalUpdates; private Boolean thresholdModelTrained; private Integer thresholdModelTrainingDataSize; - private Long modelSizeInBytes; - private String nodeId; - private String taskId; - private String adTaskType; private Integer detectorTaskSlots; private Boolean totalEntitiesInited; private Integer totalEntitiesCount; @@ -66,15 +52,14 @@ public class ADTaskProfile implements ToXContentObject, Writeable { private Integer runningEntitiesCount; private List runningEntities; private Long latestHCTaskRunTime; - - private List entityTaskProfiles; + protected List entityTaskProfiles; public ADTaskProfile() { } public ADTaskProfile(ADTask adTask) { - this.adTask = adTask; + super(adTask); } public ADTaskProfile( @@ -86,13 +71,9 @@ public ADTaskProfile( long modelSizeInBytes, String nodeId ) { - this.taskId = taskId; - this.shingleSize = shingleSize; - this.rcfTotalUpdates = rcfTotalUpdates; + super(taskId, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId); this.thresholdModelTrained = thresholdModelTrained; this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; - this.modelSizeInBytes = modelSizeInBytes; - this.nodeId = nodeId; } public ADTaskProfile( @@ -113,15 +94,9 @@ public ADTaskProfile( List runningEntities, Long latestHCTaskRunTime ) { - this.adTask = adTask; - this.shingleSize = shingleSize; - this.rcfTotalUpdates = rcfTotalUpdates; + super(adTask, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, adTaskType); this.thresholdModelTrained = thresholdModelTrained; this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; - this.modelSizeInBytes = modelSizeInBytes; - this.nodeId = nodeId; - this.taskId = taskId; - this.adTaskType = adTaskType; this.detectorTaskSlots = detectorTaskSlots; this.totalEntitiesInited = totalEntitiesInited; this.totalEntitiesCount = totalEntitiesCount; @@ -133,9 +108,9 @@ public ADTaskProfile( public ADTaskProfile(StreamInput input) throws IOException { if (input.readBoolean()) { - this.adTask = new ADTask(input); + this.task = new ADTask(input); } else { - 
this.adTask = null; + this.task = null; } this.shingleSize = input.readOptionalInt(); this.rcfTotalUpdates = input.readOptionalLong(); @@ -145,7 +120,7 @@ public ADTaskProfile(StreamInput input) throws IOException { this.nodeId = input.readOptionalString(); if (input.available() > 0) { this.taskId = input.readOptionalString(); - this.adTaskType = input.readOptionalString(); + this.taskType = input.readOptionalString(); this.detectorTaskSlots = input.readOptionalInt(); this.totalEntitiesInited = input.readOptionalBoolean(); this.totalEntitiesCount = input.readOptionalInt(); @@ -155,7 +130,7 @@ public ADTaskProfile(StreamInput input) throws IOException { this.runningEntities = input.readStringList(); } if (input.readBoolean()) { - this.entityTaskProfiles = input.readList(ADEntityTaskProfile::new); + this.entityTaskProfiles = input.readList(EntityTaskProfile::new); } this.latestHCTaskRunTime = input.readOptionalLong(); } @@ -167,9 +142,9 @@ public void writeTo(StreamOutput out) throws IOException { } public void writeTo(StreamOutput out, Version adVersion) throws IOException { - if (adTask != null) { + if (task != null) { out.writeBoolean(true); - adTask.writeTo(out); + task.writeTo(out); } else { out.writeBoolean(false); } @@ -182,7 +157,7 @@ public void writeTo(StreamOutput out, Version adVersion) throws IOException { out.writeOptionalString(nodeId); if (adVersion != null) { out.writeOptionalString(taskId); - out.writeOptionalString(adTaskType); + out.writeOptionalString(taskType); out.writeOptionalInt(detectorTaskSlots); out.writeOptionalBoolean(totalEntitiesInited); out.writeOptionalInt(totalEntitiesCount); @@ -207,33 +182,13 @@ public void writeTo(StreamOutput out, Version adVersion) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); - if (adTask != null) { - xContentBuilder.field(AD_TASK_FIELD, adTask); - } - if (shingleSize != null) { - xContentBuilder.field(SHINGLE_SIZE_FIELD, shingleSize); - } - if (rcfTotalUpdates != null) { - xContentBuilder.field(RCF_TOTAL_UPDATES_FIELD, rcfTotalUpdates); - } + super.toXContent(xContentBuilder); if (thresholdModelTrained != null) { xContentBuilder.field(THRESHOLD_MODEL_TRAINED_FIELD, thresholdModelTrained); } if (thresholdModelTrainingDataSize != null) { xContentBuilder.field(THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD, thresholdModelTrainingDataSize); } - if (modelSizeInBytes != null) { - xContentBuilder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); - } - if (nodeId != null) { - xContentBuilder.field(NODE_ID_FIELD, nodeId); - } - if (taskId != null) { - xContentBuilder.field(TASK_ID_FIELD, taskId); - } - if (adTaskType != null) { - xContentBuilder.field(AD_TASK_TYPE_FIELD, adTaskType); - } if (detectorTaskSlots != null) { xContentBuilder.field(DETECTOR_TASK_SLOTS_FIELD, detectorTaskSlots); } @@ -252,12 +207,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (runningEntities != null) { xContentBuilder.field(RUNNING_ENTITIES_FIELD, runningEntities); } - if (entityTaskProfiles != null && entityTaskProfiles.size() > 0) { - xContentBuilder.field(ENTITY_TASK_PROFILE_FIELD, entityTaskProfiles.toArray()); - } if (latestHCTaskRunTime != null) { xContentBuilder.field(LATEST_HC_TASK_RUN_TIME_FIELD, latestHCTaskRunTime); } + if (entityTaskProfiles != null && entityTaskProfiles.size() > 0) { + xContentBuilder.field(ENTITY_TASK_PROFILE_FIELD, entityTaskProfiles.toArray()); + } return 
xContentBuilder.endObject(); } @@ -277,7 +232,7 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { Integer pendingEntitiesCount = null; Integer runningEntitiesCount = null; List runningEntities = null; - List entityTaskProfiles = null; + List entityTaskProfiles = null; Long latestHCTaskRunTime = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -310,7 +265,7 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { case TASK_ID_FIELD: taskId = parser.text(); break; - case AD_TASK_TYPE_FIELD: + case TASK_TYPE_FIELD: taskType = parser.text(); break; case DETECTOR_TASK_SLOTS_FIELD: @@ -339,7 +294,7 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { entityTaskProfiles = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - entityTaskProfiles.add(ADEntityTaskProfile.parse(parser)); + entityTaskProfiles.add(EntityTaskProfile.parse(parser)); } break; case LATEST_HC_TASK_RUN_TIME_FIELD: @@ -370,30 +325,6 @@ public static ADTaskProfile parse(XContentParser parser) throws IOException { ); } - public ADTask getAdTask() { - return adTask; - } - - public void setAdTask(ADTask adTask) { - this.adTask = adTask; - } - - public Integer getShingleSize() { - return shingleSize; - } - - public void setShingleSize(Integer shingleSize) { - this.shingleSize = shingleSize; - } - - public Long getRcfTotalUpdates() { - return rcfTotalUpdates; - } - - public void setRcfTotalUpdates(Long rcfTotalUpdates) { - this.rcfTotalUpdates = rcfTotalUpdates; - } - public Boolean getThresholdModelTrained() { return thresholdModelTrained; } @@ -410,38 +341,6 @@ public void setThresholdModelTrainingDataSize(Integer thresholdModelTrainingData this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; } - public Long getModelSizeInBytes() { - return modelSizeInBytes; - } - - public void setModelSizeInBytes(Long modelSizeInBytes) { - this.modelSizeInBytes = modelSizeInBytes; - } - - public String getNodeId() { - return nodeId; - } - - public void setNodeId(String nodeId) { - this.nodeId = nodeId; - } - - public String getTaskId() { - return taskId; - } - - public void setTaskId(String taskId) { - this.taskId = taskId; - } - - public String getAdTaskType() { - return adTaskType; - } - - public void setAdTaskType(String adTaskType) { - this.adTaskType = adTaskType; - } - public boolean getTotalEntitiesInited() { return totalEntitiesInited != null && totalEntitiesInited.booleanValue(); } @@ -498,11 +397,11 @@ public void setRunningEntities(List runningEntities) { this.runningEntities = runningEntities; } - public List getEntityTaskProfiles() { + public List getEntityTaskProfiles() { return entityTaskProfiles; } - public void setEntityTaskProfiles(List entityTaskProfiles) { + public void setEntityTaskProfiles(List entityTaskProfiles) { this.entityTaskProfiles = entityTaskProfiles; } @@ -514,15 +413,9 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; ADTaskProfile that = (ADTaskProfile) o; - return Objects.equals(adTask, that.adTask) - && Objects.equals(shingleSize, that.shingleSize) - && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) + return super.equals(o) && Objects.equals(thresholdModelTrained, that.thresholdModelTrained) && Objects.equals(thresholdModelTrainingDataSize, that.thresholdModelTrainingDataSize) - && 
Objects.equals(modelSizeInBytes, that.modelSizeInBytes) - && Objects.equals(nodeId, that.nodeId) - && Objects.equals(taskId, that.taskId) - && Objects.equals(adTaskType, that.adTaskType) && Objects.equals(detectorTaskSlots, that.detectorTaskSlots) && Objects.equals(totalEntitiesInited, that.totalEntitiesInited) && Objects.equals(totalEntitiesCount, that.totalEntitiesCount) @@ -536,17 +429,11 @@ public boolean equals(Object o) { @Generated @Override public int hashCode() { - return Objects + int hash = super.hashCode(); + hash = 89 * hash + Objects .hash( - adTask, - shingleSize, - rcfTotalUpdates, thresholdModelTrained, thresholdModelTrainingDataSize, - modelSizeInBytes, - nodeId, - taskId, - adTaskType, detectorTaskSlots, totalEntitiesInited, totalEntitiesCount, @@ -554,15 +441,17 @@ public int hashCode() { runningEntitiesCount, runningEntities, entityTaskProfiles, - latestHCTaskRunTime + latestHCTaskRunTime, + entityTaskProfiles ); + return hash; } @Override public String toString() { return "ADTaskProfile{" + "adTask=" - + adTask + + task + ", shingleSize=" + shingleSize + ", rcfTotalUpdates=" @@ -580,7 +469,7 @@ public String toString() { + taskId + '\'' + ", adTaskType='" - + adTaskType + + taskType + '\'' + ", detectorTaskSlots=" + detectorTaskSlots @@ -600,4 +489,9 @@ public String toString() { + entityTaskProfiles + '}'; } + + @Override + protected String getTaskFieldName() { + return AD_TASK_FIELD; + } } diff --git a/src/main/java/org/opensearch/ad/model/ADTaskType.java b/src/main/java/org/opensearch/ad/model/ADTaskType.java index d235bad7e..1781af555 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskType.java +++ b/src/main/java/org/opensearch/ad/model/ADTaskType.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; +// enum names need to start with REALTIME or HISTORICAL we use prefix in TaskManager to check if a task is of certain type (e.g., historical) public enum ADTaskType implements TaskType { @Deprecated HISTORICAL, @@ -29,9 +30,9 @@ public enum ADTaskType implements TaskType { HISTORICAL_HC_ENTITY; public static List HISTORICAL_DETECTOR_TASK_TYPES = ImmutableList - .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL); + .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.HISTORICAL); public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList - .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL_HC_ENTITY, ADTaskType.HISTORICAL); + .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.HISTORICAL_HC_ENTITY, ADTaskType.HISTORICAL); public static List REALTIME_TASK_TYPES = ImmutableList .of(ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.REALTIME_HC_DETECTOR); public static List ALL_DETECTOR_TASK_TYPES = ImmutableList diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index aa86fa842..115675d90 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -124,7 +124,8 @@ public AnomalyDetector( List categoryFields, User user, String resultIndex, - ImputationOption imputationOption + ImputationOption imputationOption, + Double transformDecay ) { super( detectorId, @@ -144,7 +145,8 @@ public AnomalyDetector( user, resultIndex, detectionInterval, - imputationOption + imputationOption, + transformDecay ); 
checkAndThrowValidationErrors(ValidationAspect.DETECTOR); @@ -211,6 +213,7 @@ public AnomalyDetector(StreamInput input) throws IOException { this.imputationOption = null; } this.imputer = createImputer(); + this.transformDecay = input.readDouble(); } public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @@ -264,6 +267,7 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } + output.writeDouble(transformDecay); } @Override @@ -350,6 +354,7 @@ public static AnomalyDetector parse( List categoryField = null; ImputationOption imputationOption = null; + Double transformDecay = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -460,6 +465,9 @@ public static AnomalyDetector parse( case IMPUTATION_OPTION_FIELD: imputationOption = ImputationOption.parse(parser); break; + case TRANSFORM_DECAY_FIELD: + transformDecay = parser.doubleValue(); + break; default: parser.skipChildren(); break; @@ -483,7 +491,8 @@ public static AnomalyDetector parse( categoryField, user, resultIndex, - imputationOption + imputationOption, + transformDecay ); detector.setDetectionDateRange(detectionDateRange); return detector; diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java index 4ee4e0ee7..bdfe4eb3c 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java @@ -182,6 +182,16 @@ So if we detect anomaly late, we get the baseDimension values from the past (cur // rcf score threshold at the time of writing a result private final Double threshold; protected final Double confidence; + /* + * model id for easy aggregations of entities. The front end needs to query + * for entities ordered by the descending/ascending order of feature values. + * After supporting multi-category fields, it is hard to write such queries + * since the entity information is stored in a nested object array. + * Also, the front end has all code/queries/ helper functions in place to + * rely on a single key per entity combo. Adding model id to forecast result + * to help the transition to multi-categorical field less painful. 
+ */ + private final String modelId; // used when indexing exception or error or an empty result public AnomalyResult( @@ -255,12 +265,12 @@ public AnomalyResult( entity, user, schemaVersion, - modelId, taskId ); this.confidence = confidence; this.anomalyScore = anomalyScore; this.anomalyGrade = anomalyGrade; + this.modelId = modelId; this.approxAnomalyStartTime = approxAnomalyStartTime; this.relevantAttribution = relevantAttribution; this.pastValues = pastValues; @@ -422,6 +432,7 @@ public static AnomalyResult fromRawTRCFResult( public AnomalyResult(StreamInput input) throws IOException { super(input); + this.modelId = input.readOptionalString(); this.confidence = input.readDouble(); this.anomalyScore = input.readDouble(); this.anomalyGrade = input.readDouble(); @@ -502,7 +513,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(CommonName.ERROR_FIELD, error); } if (optionalEntity.isPresent()) { - xContentBuilder.field(CommonName.ENTITY_FIELD, optionalEntity.get()); + xContentBuilder.field(CommonName.ENTITY_KEY, optionalEntity.get()); } if (user != null) { xContentBuilder.field(CommonName.USER_FIELD, user); @@ -598,7 +609,7 @@ public static AnomalyResult parse(XContentParser parser) throws IOException { case CommonName.ERROR_FIELD: error = parser.text(); break; - case CommonName.ENTITY_FIELD: + case CommonName.ENTITY_KEY: entity = Entity.parse(parser); break; case CommonName.USER_FIELD: @@ -675,7 +686,8 @@ public boolean equals(Object o) { if (getClass() != o.getClass()) return false; AnomalyResult that = (AnomalyResult) o; - return Objects.equal(confidence, that.confidence) + return Objects.equal(modelId, that.modelId) + && Objects.equal(confidence, that.confidence) && Objects.equal(anomalyScore, that.anomalyScore) && Objects.equal(anomalyGrade, that.anomalyGrade) && Objects.equal(approxAnomalyStartTime, that.approxAnomalyStartTime) @@ -692,6 +704,7 @@ public int hashCode() { int result = super.hashCode(); result = prime * result + Objects .hashCode( + modelId, confidence, anomalyScore, anomalyGrade, @@ -710,6 +723,7 @@ public String toString() { return super.toString() + ", " + new ToStringBuilder(this) + .append("modelId", modelId) .append("confidence", confidence) .append("anomalyScore", anomalyScore) .append("anomalyGrade", anomalyGrade) @@ -757,6 +771,10 @@ public Double getThreshold() { return threshold; } + public String getModelId() { + return modelId; + } + /** * Anomaly result index consists of overwhelmingly (99.5%) zero-grade non-error documents. * This function exclude the majority case. 
@@ -772,6 +790,7 @@ public boolean isHighPriority() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + out.writeOptionalString(modelId); out.writeDouble(confidence); out.writeDouble(anomalyScore); out.writeDouble(anomalyGrade); diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfile.java b/src/main/java/org/opensearch/ad/model/DetectorProfile.java index 77418552e..7ffd8c9f6 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorProfile.java +++ b/src/main/java/org/opensearch/ad/model/DetectorProfile.java @@ -13,131 +13,28 @@ import java.io.IOException; -import org.apache.commons.lang.builder.EqualsBuilder; -import org.apache.commons.lang.builder.HashCodeBuilder; -import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.ConfigProfile; -public class DetectorProfile implements Writeable, ToXContentObject, Mergeable { - private DetectorState state; - private String error; - private ModelProfileOnNode[] modelProfile; - private int shingleSize; - private String coordinatingNode; - private long totalSizeInBytes; - private InitProgressProfile initProgress; - private Long totalEntities; - private Long activeEntities; - private ADTaskProfile adTaskProfile; - private long modelCount; +public class DetectorProfile extends ConfigProfile { - public XContentBuilder toXContent(XContentBuilder builder) throws IOException { - return toXContent(builder, ToXContent.EMPTY_PARAMS); - } - - public DetectorProfile(StreamInput in) throws IOException { - if (in.readBoolean()) { - this.state = in.readEnum(DetectorState.class); - } - - this.error = in.readOptionalString(); - this.modelProfile = in.readOptionalArray(ModelProfileOnNode::new, ModelProfileOnNode[]::new); - this.shingleSize = in.readOptionalInt(); - this.coordinatingNode = in.readOptionalString(); - this.totalSizeInBytes = in.readOptionalLong(); - this.totalEntities = in.readOptionalLong(); - this.activeEntities = in.readOptionalLong(); - if (in.readBoolean()) { - this.initProgress = new InitProgressProfile(in); - } - if (in.readBoolean()) { - this.adTaskProfile = new ADTaskProfile(in); - } - this.modelCount = in.readVLong(); - } - - private DetectorProfile() {} - - public static class Builder { - private DetectorState state = null; - private String error = null; - private ModelProfileOnNode[] modelProfile = null; - private int shingleSize = -1; - private String coordinatingNode = null; - private long totalSizeInBytes = -1; - private InitProgressProfile initProgress = null; - private Long totalEntities; - private Long activeEntities; + public static class Builder extends ConfigProfile.Builder { private ADTaskProfile adTaskProfile; - private long modelCount = 0; public Builder() {} - public Builder state(DetectorState state) { - this.state = state; - return this; - } - - public Builder error(String error) { - this.error = error; - return this; - } - - public Builder modelProfile(ModelProfileOnNode[] modelProfile) { - this.modelProfile = modelProfile; - return this; - } - - public Builder modelCount(long modelCount) { - this.modelCount = modelCount; - return this; - } - - public Builder 
shingleSize(int shingleSize) { - this.shingleSize = shingleSize; - return this; - } - - public Builder coordinatingNode(String coordinatingNode) { - this.coordinatingNode = coordinatingNode; - return this; - } - - public Builder totalSizeInBytes(long totalSizeInBytes) { - this.totalSizeInBytes = totalSizeInBytes; - return this; - } - - public Builder initProgress(InitProgressProfile initProgress) { - this.initProgress = initProgress; - return this; - } - - public Builder totalEntities(Long totalEntities) { - this.totalEntities = totalEntities; - return this; - } - - public Builder activeEntities(Long activeEntities) { - this.activeEntities = activeEntities; - return this; - } - - public Builder adTaskProfile(ADTaskProfile adTaskProfile) { + @Override + public Builder taskProfile(ADTaskProfile adTaskProfile) { this.adTaskProfile = adTaskProfile; return this; } + @Override public DetectorProfile build() { DetectorProfile profile = new DetectorProfile(); - profile.state = this.state; - profile.error = this.error; + profile.state = state; + profile.error = error; profile.modelProfile = modelProfile; profile.modelCount = modelCount; profile.shingleSize = shingleSize; @@ -146,320 +43,25 @@ public DetectorProfile build() { profile.initProgress = initProgress; profile.totalEntities = totalEntities; profile.activeEntities = activeEntities; - profile.adTaskProfile = adTaskProfile; + profile.taskProfile = adTaskProfile; return profile; } } - @Override - public void writeTo(StreamOutput out) throws IOException { - if (state == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - out.writeEnum(state); - } - - out.writeOptionalString(error); - out.writeOptionalArray(modelProfile); - out.writeOptionalInt(shingleSize); - out.writeOptionalString(coordinatingNode); - out.writeOptionalLong(totalSizeInBytes); - out.writeOptionalLong(totalEntities); - out.writeOptionalLong(activeEntities); - if (initProgress == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - initProgress.writeTo(out); - } - if (adTaskProfile == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - adTaskProfile.writeTo(out); - } - out.writeVLong(modelCount); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - XContentBuilder xContentBuilder = builder.startObject(); - - if (state != null) { - xContentBuilder.field(ADCommonName.STATE, state); - } - if (error != null) { - xContentBuilder.field(ADCommonName.ERROR, error); - } - if (modelProfile != null && modelProfile.length > 0) { - xContentBuilder.startArray(ADCommonName.MODELS); - for (ModelProfileOnNode profile : modelProfile) { - profile.toXContent(xContentBuilder, params); - } - xContentBuilder.endArray(); - } - if (shingleSize != -1) { - xContentBuilder.field(ADCommonName.SHINGLE_SIZE, shingleSize); - } - if (coordinatingNode != null && !coordinatingNode.isEmpty()) { - xContentBuilder.field(ADCommonName.COORDINATING_NODE, coordinatingNode); - } - if (totalSizeInBytes != -1) { - xContentBuilder.field(ADCommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); - } - if (initProgress != null) { - xContentBuilder.field(ADCommonName.INIT_PROGRESS, initProgress); - } - if (totalEntities != null) { - xContentBuilder.field(ADCommonName.TOTAL_ENTITIES, totalEntities); - } - if (activeEntities != null) { - xContentBuilder.field(ADCommonName.ACTIVE_ENTITIES, activeEntities); - } - if (adTaskProfile != null) { - xContentBuilder.field(ADCommonName.AD_TASK, adTaskProfile); - } - if 
(modelCount > 0) { - xContentBuilder.field(ADCommonName.MODEL_COUNT, modelCount); - } - return xContentBuilder.endObject(); - } - - public DetectorState getState() { - return state; - } - - public void setState(DetectorState state) { - this.state = state; - } - - public String getError() { - return error; - } - - public void setError(String error) { - this.error = error; - } - - public ModelProfileOnNode[] getModelProfile() { - return modelProfile; - } - - public void setModelProfile(ModelProfileOnNode[] modelProfile) { - this.modelProfile = modelProfile; - } - - public int getShingleSize() { - return shingleSize; - } - - public void setShingleSize(int shingleSize) { - this.shingleSize = shingleSize; - } - - public String getCoordinatingNode() { - return coordinatingNode; - } - - public void setCoordinatingNode(String coordinatingNode) { - this.coordinatingNode = coordinatingNode; - } - - public long getTotalSizeInBytes() { - return totalSizeInBytes; - } - - public void setTotalSizeInBytes(long totalSizeInBytes) { - this.totalSizeInBytes = totalSizeInBytes; - } - - public InitProgressProfile getInitProgress() { - return initProgress; - } - - public void setInitProgress(InitProgressProfile initProgress) { - this.initProgress = initProgress; - } - - public Long getTotalEntities() { - return totalEntities; - } - - public void setTotalEntities(Long totalEntities) { - this.totalEntities = totalEntities; - } - - public Long getActiveEntities() { - return activeEntities; - } - - public void setActiveEntities(Long activeEntities) { - this.activeEntities = activeEntities; - } - - public ADTaskProfile getAdTaskProfile() { - return adTaskProfile; - } - - public void setAdTaskProfile(ADTaskProfile adTaskProfile) { - this.adTaskProfile = adTaskProfile; - } - - public long getModelCount() { - return modelCount; - } - - public void setModelCount(long modelCount) { - this.modelCount = modelCount; - } - - @Override - public void merge(Mergeable other) { - if (this == other || other == null || getClass() != other.getClass()) { - return; - } - DetectorProfile otherProfile = (DetectorProfile) other; - if (otherProfile.getState() != null) { - this.state = otherProfile.getState(); - } - if (otherProfile.getError() != null) { - this.error = otherProfile.getError(); - } - if (otherProfile.getCoordinatingNode() != null) { - this.coordinatingNode = otherProfile.getCoordinatingNode(); - } - if (otherProfile.getShingleSize() != -1) { - this.shingleSize = otherProfile.getShingleSize(); - } - if (otherProfile.getModelProfile() != null) { - this.modelProfile = otherProfile.getModelProfile(); - } - if (otherProfile.getTotalSizeInBytes() != -1) { - this.totalSizeInBytes = otherProfile.getTotalSizeInBytes(); - } - if (otherProfile.getInitProgress() != null) { - this.initProgress = otherProfile.getInitProgress(); - } - if (otherProfile.getTotalEntities() != null) { - this.totalEntities = otherProfile.getTotalEntities(); - } - if (otherProfile.getActiveEntities() != null) { - this.activeEntities = otherProfile.getActiveEntities(); - } - if (otherProfile.getAdTaskProfile() != null) { - this.adTaskProfile = otherProfile.getAdTaskProfile(); - } - if (otherProfile.getModelCount() > 0) { - this.modelCount = otherProfile.getModelCount(); - } - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - if (obj instanceof DetectorProfile) { - DetectorProfile other = (DetectorProfile) obj; + public DetectorProfile() 
{} - EqualsBuilder equalsBuilder = new EqualsBuilder(); - if (state != null) { - equalsBuilder.append(state, other.state); - } - if (error != null) { - equalsBuilder.append(error, other.error); - } - if (modelProfile != null && modelProfile.length > 0) { - equalsBuilder.append(modelProfile, other.modelProfile); - } - if (shingleSize != -1) { - equalsBuilder.append(shingleSize, other.shingleSize); - } - if (coordinatingNode != null) { - equalsBuilder.append(coordinatingNode, other.coordinatingNode); - } - if (totalSizeInBytes != -1) { - equalsBuilder.append(totalSizeInBytes, other.totalSizeInBytes); - } - if (initProgress != null) { - equalsBuilder.append(initProgress, other.initProgress); - } - if (totalEntities != null) { - equalsBuilder.append(totalEntities, other.totalEntities); - } - if (activeEntities != null) { - equalsBuilder.append(activeEntities, other.activeEntities); - } - if (adTaskProfile != null) { - equalsBuilder.append(adTaskProfile, other.adTaskProfile); - } - if (modelCount > 0) { - equalsBuilder.append(modelCount, other.modelCount); - } - return equalsBuilder.isEquals(); - } - return false; + public DetectorProfile(StreamInput in) throws IOException { + super(in); } @Override - public int hashCode() { - return new HashCodeBuilder() - .append(state) - .append(error) - .append(modelProfile) - .append(shingleSize) - .append(coordinatingNode) - .append(totalSizeInBytes) - .append(initProgress) - .append(totalEntities) - .append(activeEntities) - .append(adTaskProfile) - .append(modelCount) - .toHashCode(); + protected ADTaskProfile createTaskProfile(StreamInput in) throws IOException { + return new ADTaskProfile(in); } @Override - public String toString() { - ToStringBuilder toStringBuilder = new ToStringBuilder(this); - - if (state != null) { - toStringBuilder.append(ADCommonName.STATE, state); - } - if (error != null) { - toStringBuilder.append(ADCommonName.ERROR, error); - } - if (modelProfile != null && modelProfile.length > 0) { - toStringBuilder.append(modelProfile); - } - if (shingleSize != -1) { - toStringBuilder.append(ADCommonName.SHINGLE_SIZE, shingleSize); - } - if (coordinatingNode != null) { - toStringBuilder.append(ADCommonName.COORDINATING_NODE, coordinatingNode); - } - if (totalSizeInBytes != -1) { - toStringBuilder.append(ADCommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); - } - if (initProgress != null) { - toStringBuilder.append(ADCommonName.INIT_PROGRESS, initProgress); - } - if (totalEntities != null) { - toStringBuilder.append(ADCommonName.TOTAL_ENTITIES, totalEntities); - } - if (activeEntities != null) { - toStringBuilder.append(ADCommonName.ACTIVE_ENTITIES, activeEntities); - } - if (adTaskProfile != null) { - toStringBuilder.append(ADCommonName.AD_TASK, adTaskProfile); - } - if (modelCount > 0) { - toStringBuilder.append(ADCommonName.MODEL_COUNT, modelCount); - } - return toStringBuilder.toString(); + protected String getTaskFieldName() { + return ADCommonName.AD_TASK; } } diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfileName.java b/src/main/java/org/opensearch/ad/model/DetectorProfileName.java deleted file mode 100644 index 443066ac8..000000000 --- a/src/main/java/org/opensearch/ad/model/DetectorProfileName.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. 
See - * GitHub history for details. - */ - -package org.opensearch.ad.model; - -import java.util.Collection; -import java.util.Set; - -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.timeseries.Name; - -public enum DetectorProfileName implements Name { - STATE(ADCommonName.STATE), - ERROR(ADCommonName.ERROR), - COORDINATING_NODE(ADCommonName.COORDINATING_NODE), - SHINGLE_SIZE(ADCommonName.SHINGLE_SIZE), - TOTAL_SIZE_IN_BYTES(ADCommonName.TOTAL_SIZE_IN_BYTES), - MODELS(ADCommonName.MODELS), - INIT_PROGRESS(ADCommonName.INIT_PROGRESS), - TOTAL_ENTITIES(ADCommonName.TOTAL_ENTITIES), - ACTIVE_ENTITIES(ADCommonName.ACTIVE_ENTITIES), - AD_TASK(ADCommonName.AD_TASK); - - private String name; - - DetectorProfileName(String name) { - this.name = name; - } - - /** - * Get profile name - * - * @return name - */ - @Override - public String getName() { - return name; - } - - public static DetectorProfileName getName(String name) { - switch (name) { - case ADCommonName.STATE: - return STATE; - case ADCommonName.ERROR: - return ERROR; - case ADCommonName.COORDINATING_NODE: - return COORDINATING_NODE; - case ADCommonName.SHINGLE_SIZE: - return SHINGLE_SIZE; - case ADCommonName.TOTAL_SIZE_IN_BYTES: - return TOTAL_SIZE_IN_BYTES; - case ADCommonName.MODELS: - return MODELS; - case ADCommonName.INIT_PROGRESS: - return INIT_PROGRESS; - case ADCommonName.TOTAL_ENTITIES: - return TOTAL_ENTITIES; - case ADCommonName.ACTIVE_ENTITIES: - return ACTIVE_ENTITIES; - case ADCommonName.AD_TASK: - return AD_TASK; - default: - throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); - } - } - - public static Set getNames(Collection names) { - return Name.getNameFromCollection(names, DetectorProfileName::getName); - } -} diff --git a/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java b/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java index 7eeb02e6c..2d58c2be8 100644 --- a/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java @@ -13,6 +13,8 @@ import java.util.List; +import org.opensearch.timeseries.model.Mergeable; + public class EntityAnomalyResult implements Mergeable { private List anomalyResults; diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java similarity index 70% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java rename to src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java index 05f9480a7..f2e5e3d8a 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java @@ -16,28 +16,27 @@ import java.time.Clock; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import java.util.Random; +import java.util.function.Function; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; +import 
org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; -public class CheckpointMaintainWorker extends ScheduledWorker { - private static final Logger LOG = LogManager.getLogger(CheckpointMaintainWorker.class); - public static final String WORKER_NAME = "checkpoint-maintain"; +public class ADCheckpointMaintainWorker extends CheckpointMaintainWorker { + public static final String WORKER_NAME = "ad-checkpoint-maintain"; - private CheckPointMaintainRequestAdapter adapter; - - public CheckpointMaintainWorker( + public ADCheckpointMaintainWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, @@ -51,10 +50,10 @@ public CheckpointMaintainWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, - CheckpointWriteWorker checkpointWriteQueue, + ADCheckpointWriteWorker checkpointWriteQueue, Duration stateTtl, NodeStateManager nodeStateManager, - CheckPointMaintainRequestAdapter adapter + Function> converter ) { super( WORKER_NAME, @@ -65,6 +64,7 @@ public CheckpointMaintainWorker( random, adCircuitBreakerService, threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, settings, maxQueuedTaskRatio, clock, @@ -73,7 +73,9 @@ public CheckpointMaintainWorker( maintenanceFreqConstant, checkpointWriteQueue, stateTtl, - nodeStateManager + nodeStateManager, + converter, + AnalysisType.AD ); this.batchSize = AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); @@ -87,18 +89,5 @@ public CheckpointMaintainWorker( AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, it -> this.expectedExecutionTimeInMilliSecsPerRequest = it ); - this.adapter = adapter; - } - - @Override - protected List transformRequests(List requests) { - List allRequests = new ArrayList<>(); - for (CheckpointMaintainRequest request : requests) { - Optional converted = adapter.convert(request); - if (!converted.isEmpty()) { - allRequests.add(converted.get()); - } - } - return allRequests; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java new file mode 100644 index 000000000..ae25071d3 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.stats.StatNames; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: + * a). If a checkpoint is not found, we forward that request to the cold start queue. + * b). When a request gets errors, the queue does not change its expiry time and puts + * that request to the end of the queue and automatically retries them before they expire. + * c) When a checkpoint is found, we load that point to memory and score the input + * data point and save the result if a complete model exists. Otherwise, we enqueue + * the sample. If we can host that model in memory (e.g., there is enough memory), + * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the + * updated checkpoint back to disk. 
+ * + */ +public class ADCheckpointReadWorker extends + CheckpointReadWorker { + public static final String WORKER_NAME = "ad-checkpoint-read"; + + public ADCheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADModelManager modelManager, + ADCheckpointDao checkpointDao, + ADColdStartWorker entityColdStartQueue, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + Provider cacheProvider, + Duration stateTtl, + ADCheckpointWriteWorker checkpointWriteQueue, + ADStats adStats, + ADSaveResultStrategy resultWriteWorker + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + executionTtl, + modelManager, + checkpointDao, + entityColdStartQueue, + stateManager, + indexUtil, + cacheProvider, + stateTtl, + checkpointWriteQueue, + adStats, + AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ADCommonName.CHECKPOINT_INDEX_NAME, + StatNames.AD_MODEL_CORRUTPION_COUNT, + AnalysisType.AD, + resultWriteWorker + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java new file mode 100644 index 000000000..fd0bc3a66 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADCheckpointWriteWorker extends + CheckpointWriteWorker { + public static final String WORKER_NAME = "ad-checkpoint-write"; + + public ADCheckpointWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADCheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager adNodeStateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + adNodeStateManager, + checkpoint, + indexName, + checkpointInterval, + AnalysisType.AD + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java similarity index 68% rename from src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java rename to src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java index 701fc25d4..9d3037db1 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java @@ -16,17 +16,27 @@ import java.time.Clock; import java.time.Duration; -import java.util.List; import java.util.Random; -import java.util.stream.Collectors; -import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import 
org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; /** * A queue slowly releasing low-priority requests to CheckpointReadQueue @@ -43,10 +53,11 @@ * entity requests.  * */ -public class ColdEntityWorker extends ScheduledWorker { - public static final String WORKER_NAME = "cold-entity"; +public class ADColdEntityWorker extends + ColdEntityWorker { + public static final String WORKER_NAME = "ad-cold-entity"; - public ColdEntityWorker( + public ADColdEntityWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, @@ -60,7 +71,7 @@ public ColdEntityWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, - CheckpointReadWorker checkpointReadQueue, + ADCheckpointReadWorker checkpointReadQueue, Duration stateTtl, NodeStateManager nodeStateManager ) { @@ -73,6 +84,7 @@ public ColdEntityWorker( random, adCircuitBreakerService, threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, settings, maxQueuedTaskRatio, clock, @@ -81,25 +93,10 @@ public ColdEntityWorker( maintenanceFreqConstant, checkpointReadQueue, stateTtl, - nodeStateManager + nodeStateManager, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnalysisType.AD ); - - this.batchSize = AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, it -> this.batchSize = it); - - this.expectedExecutionTimeInMilliSecsPerRequest = AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS - .get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer( - AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, - it -> this.expectedExecutionTimeInMilliSecsPerRequest = it - ); - } - - @Override - protected List transformRequests(List requests) { - // guarantee we only send low priority requests - return requests.stream().filter(request -> request.priority == RequestPriority.LOW).collect(Collectors.toList()); } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java new file mode 100644 index 000000000..b6686d0a7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java @@ -0,0 +1,154 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
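With the rename, the batching and pacing logic moves into the shared timeseries `ColdEntityWorker`, including the low-priority filter that the removed `transformRequests` override used to apply. A rough sketch of that slow-release idea, under assumed names (`SlowReleaseSketch` and its methods are illustrative, not the base class's API):

```java
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.timeseries.ratelimit.FeatureRequest;
import org.opensearch.timeseries.ratelimit.RequestPriority;

// Sketch of the "slow release" idea only; the real logic lives in the shared
// timeseries ColdEntityWorker.
class SlowReleaseSketch {
    private final int batchSize;                          // mirrors AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE
    private final long expectedExecTimeMillisPerRequest;  // mirrors AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS

    SlowReleaseSketch(int batchSize, long expectedExecTimeMillisPerRequest) {
        this.batchSize = batchSize;
        this.expectedExecTimeMillisPerRequest = expectedExecTimeMillisPerRequest;
    }

    // Guarantee only low-priority (cold entity) requests are forwarded,
    // as the removed transformRequests override did.
    List<FeatureRequest> transform(List<FeatureRequest> requests) {
        return requests.stream().filter(r -> r.getPriority() == RequestPriority.LOW).collect(Collectors.toList());
    }

    // Space out batches by their expected cost so cold entities trickle into
    // the checkpoint-read queue instead of flooding it.
    long delayBeforeNextBatchMillis() {
        return (long) batchSize * expectedExecTimeMillisPerRequest;
    }
}
```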
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A queue for HCAD model training (a.k.a. cold start). As model training is a + * pretty expensive operation, we pull cold start requests from the queue in a + * serial fashion. Each detector has an equal chance of being pulled. The equal + * probability is achieved by putting model training requests for different + * detectors into different segments and pulling requests from segments in a + * round-robin fashion. 
+ * + */ + +// suppress warning due to the use of generic type ModelState +public class ADColdStartWorker extends + ColdStartWorker { + public static final String WORKER_NAME = "ad-cold-start"; + + public ADColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADEntityColdStart entityColdStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + ADPriorityCache cacheProvider, + ADModelManager modelManager, + ADSaveResultStrategy saveStrategy + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_ENTITY_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + entityColdStarter, + stateTtl, + nodeStateManager, + cacheProvider, + AnalysisType.AD, + modelManager, + saveStrategy + ); + } + + @Override + protected ModelState createEmptyState(FeatureRequest request, String modelId, String configId) { + return new ModelState( + null, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + new Sample(), + request.getEntity(), + new ArrayDeque<>() + ); + } + + @Override + protected AnomalyResult createIndexableResult( + Config config, + String taskId, + String modelId, + Entry entry, + Optional entity + ) { + return new AnomalyResult( + config.getId(), + taskId, + ParseUtils.getFeatureData(entry.getValue(), config), + Instant.ofEpochMilli(entry.getKey() - config.getIntervalInMilliseconds()), + Instant.ofEpochMilli(entry.getKey()), + Instant.now(), + Instant.now(), + "", + entity, + config.getUser(), + config.getSchemaVersion(), + modelId + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java new file mode 100644 index 000000000..912396ebd --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
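The Javadoc above describes per-detector fairness: training requests are bucketed into per-detector segments that are polled round-robin. A self-contained sketch of that scheduling idea (illustrative only; the real scheduling lives in the shared rate-limiting queue classes):

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;

// One FIFO segment per detector, polled round-robin, so every detector has an
// equal chance of having a cold-start request pulled next.
class RoundRobinSegmentsSketch<R> {
    private final Map<String, Queue<R>> segmentsByDetector = new LinkedHashMap<>();
    private final List<String> order = new ArrayList<>();
    private int cursor = 0;

    void enqueue(String detectorId, R request) {
        segmentsByDetector.computeIfAbsent(detectorId, id -> {
            order.add(id);
            return new ArrayDeque<>();
        }).add(request);
    }

    // Pull from segments in round-robin fashion; skip empty segments.
    R pull() {
        for (int i = 0; i < order.size(); i++) {
            String detectorId = order.get((cursor + i) % order.size());
            R next = segmentsByDetector.get(detectorId).poll();
            if (next != null) {
                cursor = (cursor + i + 1) % order.size();
                return next;
            }
        }
        return null; // nothing queued
    }
}
```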
+ */ + +package org.opensearch.ad.ratelimit; + +import java.io.IOException; + +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ADResultWriteRequest extends ResultWriteRequest { + + public ADResultWriteRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + super(expirationEpochMs, detectorId, priority, result, resultIndex); + } + + public ADResultWriteRequest(StreamInput in) throws IOException { + super(in, AnomalyResult::new); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java new file mode 100644 index 000000000..b57e99f1c --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; + +public class ADResultWriteWorker extends + ResultWriteWorker { + public static final String WORKER_NAME = "ad-result-write"; + + public ADResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADIndexMemoryPressureAwareResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + 
mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager, + resultHandler, + xContentRegistry, + AnomalyResult::parse, + AnalysisType.AD + ); + } + + @Override + protected ADResultBulkRequest toBatchRequest(List toProcess) { + final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); + for (ADResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected ADResultWriteRequest createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + return new ADResultWriteRequest(expirationEpochMs, configId, priority, result, resultIndex); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java b/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java new file mode 100644 index 000000000..aeb265072 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADSaveResultStrategy.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.util.ParseUtils; + +public class ADSaveResultStrategy implements SaveResultStrategy { + private int resultMappingVersion; + private ADResultWriteWorker resultWriteWorker; + + public ADSaveResultStrategy(int resultMappingVersion, ADResultWriteWorker resultWriteWorker) { + this.resultMappingVersion = resultMappingVersion; + this.resultWriteWorker = resultWriteWorker; + } + + @Override + public void saveResult(ThresholdingResult result, Config config, FeatureRequest origRequest, String modelId) { + // result.getRcfScore() = 0 means the model is not initialized + // result.getGrade() = 0 means it is not an anomaly + saveResult( + result, + config, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()), + modelId, + origRequest.getCurrentFeature(), + origRequest.getEntity(), + origRequest.getTaskId() + ); + } + + @Override + public void saveResult( + ThresholdingResult result, + Config config, + Instant dataStart, + Instant dataEnd, + String modelId, + double[] currentData, + Optional entity, + String taskId + ) { + // result.getRcfScore() = 0 means the model is not initialized + // result.getGrade() = 0 means it is not an anomaly + if (result != null && result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + dataStart, + dataEnd, + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(currentData, config), + entity, + resultMappingVersion, + modelId, + taskId, + null + ); + + for (AnomalyResult r : indexableResults) { + saveResult(r, config); + } + } + } + + @Override + public void saveResult(AnomalyResult result, Config config) { + resultWriteWorker + .put( + new ADResultWriteRequest( + System.currentTimeMillis() + 
config.getIntervalInMilliseconds(), + config.getId(), + result.getAnomalyGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, + result, + config.getCustomResultIndex() + ) + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java deleted file mode 100644 index 72011e156..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY; - -import java.time.Clock; -import java.time.Duration; -import java.util.ArrayDeque; -import java.util.Locale; -import java.util.Optional; -import java.util.Random; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.breaker.CircuitBreakerService; -import org.opensearch.timeseries.util.ExceptionUtil; - -/** - * A queue for HCAD model training (a.k.a. cold start). As model training is a - * pretty expensive operation, we pull cold start requests from the queue in a - * serial fashion. Each detector has an equal chance of being pulled. The equal - * probability is achieved by putting model training requests for different - * detectors into different segments and pulling requests from segments in a - * round-robin fashion. 
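`ADSaveResultStrategy` above encodes three small policies: skip saving while `rcfScore` is 0 (the model is not initialized), write anomalies (`grade > 0`) at HIGH priority and normal results at MEDIUM, and expire each write request one detector interval after creation. Restated as a standalone, hypothetical helper for clarity:

```java
import org.opensearch.timeseries.ratelimit.RequestPriority;

// Hypothetical helper restating ADSaveResultStrategy's write policy;
// not part of the plugin.
final class ResultWritePolicy {

    // rcfScore == 0 means the RCF model is not initialized yet, so there is
    // nothing worth indexing.
    static boolean shouldSave(double rcfScore) {
        return rcfScore > 0;
    }

    // Anomalies (grade > 0) are written at high priority; normal results can wait.
    static RequestPriority priorityFor(double anomalyGrade) {
        return anomalyGrade > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM;
    }

    // Each write request expires one detector interval after creation, matching
    // the expirationEpochMs passed to ADResultWriteRequest above.
    static long expirationEpochMs(long nowMs, long detectorIntervalMs) {
        return nowMs + detectorIntervalMs;
    }
}
```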
- * - */ -public class EntityColdStartWorker extends SingleRequestWorker { - private static final Logger LOG = LogManager.getLogger(EntityColdStartWorker.class); - public static final String WORKER_NAME = "cold-start"; - - private final EntityColdStarter entityColdStarter; - private final CacheProvider cacheProvider; - - public EntityColdStartWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, - Setting maxHeapPercentForQueueSetting, - ClusterService clusterService, - Random random, - CircuitBreakerService adCircuitBreakerService, - ThreadPool threadPool, - Settings settings, - float maxQueuedTaskRatio, - Clock clock, - float mediumSegmentPruneRatio, - float lowSegmentPruneRatio, - int maintenanceFreqConstant, - Duration executionTtl, - EntityColdStarter entityColdStarter, - Duration stateTtl, - NodeStateManager nodeStateManager, - CacheProvider cacheProvider - ) { - super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, - maxHeapPercentForQueueSetting, - clusterService, - random, - adCircuitBreakerService, - threadPool, - settings, - maxQueuedTaskRatio, - clock, - mediumSegmentPruneRatio, - lowSegmentPruneRatio, - maintenanceFreqConstant, - AD_ENTITY_COLD_START_QUEUE_CONCURRENCY, - executionTtl, - stateTtl, - nodeStateManager - ); - this.entityColdStarter = entityColdStarter; - this.cacheProvider = cacheProvider; - } - - @Override - protected void executeRequest(EntityRequest coldStartRequest, ActionListener listener) { - String detectorId = coldStartRequest.getId(); - - Optional modelId = coldStartRequest.getModelId(); - - if (false == modelId.isPresent()) { - String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); - LOG.warn(error); - listener.onFailure(new RuntimeException(error)); - return; - } - - ModelState modelState = new ModelState<>( - new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null), - modelId.get(), - detectorId, - ModelType.ENTITY.getName(), - clock, - 0 - ); - - ActionListener coldStartListener = ActionListener.wrap(r -> { - nodeStateManager.getConfig(detectorId, AnalysisType.AD, ActionListener.wrap(detectorOptional -> { - try { - if (!detectorOptional.isPresent()) { - LOG - .error( - new ParameterizedMessage( - "fail to load trained model [{}] to cache due to the detector not being found.", - modelState.getModelId() - ) - ); - return; - } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - EntityModel model = modelState.getModel(); - // load to cache if cold start succeeds - if (model != null && model.getTrcf() != null) { - cacheProvider.get().hostIfPossible(detector, modelState); - } - } finally { - listener.onResponse(null); - } - }, listener::onFailure)); - - }, e -> { - try { - if (ExceptionUtil.isOverloaded(e)) { - LOG.error("OpenSearch is overloaded"); - setCoolDownStart(); - } - nodeStateManager.setException(detectorId, e); - } finally { - listener.onFailure(e); - } - }); - - entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, coldStartListener); - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java deleted file mode 100644 index 875974dbb..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - 
* compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import org.opensearch.timeseries.model.Entity; - -public class EntityFeatureRequest extends EntityRequest { - private final double[] currentFeature; - private final long dataStartTimeMillis; - - public EntityFeatureRequest( - long expirationEpochMs, - String detectorId, - RequestPriority priority, - Entity entity, - double[] currentFeature, - long dataStartTimeMs - ) { - super(expirationEpochMs, detectorId, priority, entity); - this.currentFeature = currentFeature; - this.dataStartTimeMillis = dataStartTimeMs; - } - - public double[] getCurrentFeature() { - return currentFeature; - } - - public long getDataStartTimeMillis() { - return dataStartTimeMillis; - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java deleted file mode 100644 index 7acf2652a..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import java.util.Optional; - -import org.opensearch.timeseries.model.Entity; - -public class EntityRequest extends QueuedRequest { - private final Entity entity; - - /** - * - * @param expirationEpochMs Expiry time of the request - * @param detectorId Detector Id - * @param priority the entity's priority - * @param entity the entity's attributes - */ - public EntityRequest(long expirationEpochMs, String detectorId, RequestPriority priority, Entity entity) { - super(expirationEpochMs, detectorId, priority); - this.entity = entity; - } - - public Entity getEntity() { - return entity; - } - - public Optional getModelId() { - return entity.getModelId(detectorId); - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java deleted file mode 100644 index a25bf3924..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import java.io.IOException; - -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; - -public class ResultWriteRequest extends QueuedRequest implements Writeable { - private final AnomalyResult result; - // If resultIndex is null, result will be stored in default result index. 
- private final String resultIndex; - - public ResultWriteRequest( - long expirationEpochMs, - String detectorId, - RequestPriority priority, - AnomalyResult result, - String resultIndex - ) { - super(expirationEpochMs, detectorId, priority); - this.result = result; - this.resultIndex = resultIndex; - } - - public ResultWriteRequest(StreamInput in) throws IOException { - this.result = new AnomalyResult(in); - this.resultIndex = in.readOptionalString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - result.writeTo(out); - out.writeOptionalString(resultIndex); - } - - public AnomalyResult getResult() { - return result; - } - - public String getCustomResultIndex() { - return resultIndex; - } -} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractADSearchAction.java b/src/main/java/org/opensearch/ad/rest/AbstractADSearchAction.java new file mode 100644 index 000000000..ef901f40c --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/AbstractADSearchAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest; + +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.timeseries.AbstractSearchAction; + +public abstract class AbstractADSearchAction extends AbstractSearchAction { + + public AbstractADSearchAction( + List urlPaths, + List> deprecatedPaths, + String index, + Class clazz, + ActionType actionType + ) { + super(urlPaths, deprecatedPaths, index, clazz, actionType, ADEnabledSetting::isADEnabled, ADCommonMessages.DISABLED_ERR_MSG); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java index ee0d410f5..4a10b3ad9 100644 --- a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java @@ -18,6 +18,7 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.DETECTION_WINDOW_DELAY; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; +import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -31,6 +32,7 @@ public abstract class AbstractAnomalyDetectorAction extends BaseRestHandler { protected volatile Integer maxSingleEntityDetectors; protected volatile Integer maxMultiEntityDetectors; protected volatile Integer maxAnomalyFeatures; + protected volatile Integer maxCategoricalFields; public AbstractAnomalyDetectorAction(Settings settings, ClusterService clusterService) { this.requestTimeout = AD_REQUEST_TIMEOUT.get(settings); @@ -39,6 +41,7 @@ public AbstractAnomalyDetectorAction(Settings settings, ClusterService clusterSe this.maxSingleEntityDetectors = AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); this.maxMultiEntityDetectors = AD_MAX_HC_ANOMALY_DETECTORS.get(settings); this.maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); + this.maxCategoricalFields = ADNumericSetting.maxCategoricalFields(); // TODO: will add more cluster setting consumer later // TODO: inject ClusterSettings 
only if clusterService is only used to get ClusterSettings clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> requestTimeout = it); diff --git a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java index 175ac02e7..14ef4c652 100644 --- a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java @@ -12,10 +12,7 @@ package org.opensearch.ad.rest; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; -import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; -import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; @@ -26,25 +23,23 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.AnomalyDetectorJobAction; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.rest.RestJobAction; +import org.opensearch.timeseries.transport.JobRequest; import com.google.common.collect.ImmutableList; /** * This class consists of the REST handler to handle request to start/stop AD job. 
*/ -public class RestAnomalyDetectorJobAction extends BaseRestHandler { +public class RestAnomalyDetectorJobAction extends RestJobAction { public static final String AD_JOB_ACTION = "anomaly_detector_job_action"; private volatile TimeValue requestTimeout; @@ -66,40 +61,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } String detectorId = request.param(DETECTOR_ID); - long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); - long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); boolean historical = request.paramAsBoolean("historical", false); String rawPath = request.rawPath(); - DateRange detectionDateRange = parseDetectionDateRange(request); + DateRange detectionDateRange = parseInputDateRange(request); - AnomalyDetectorJobRequest anomalyDetectorJobRequest = new AnomalyDetectorJobRequest( - detectorId, - detectionDateRange, - historical, - seqNo, - primaryTerm, - rawPath - ); + JobRequest anomalyDetectorJobRequest = new JobRequest(detectorId, detectionDateRange, historical, rawPath); return channel -> client .execute(AnomalyDetectorJobAction.INSTANCE, anomalyDetectorJobRequest, new RestToXContentListener<>(channel)); } - private DateRange parseDetectionDateRange(RestRequest request) throws IOException { - if (!request.hasContent()) { - return null; - } - XContentParser parser = request.contentParser(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - DateRange dateRange = DateRange.parse(parser); - return dateRange; - } - - @Override - public List routes() { - return ImmutableList.of(); - } - @Override public List replacedRoutes() { return ImmutableList diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java index b7a3aae6c..1ad4f0a9a 100644 --- a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java @@ -17,18 +17,15 @@ import java.util.List; import java.util.Locale; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.rest.handler.AnomalyDetectorActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; -import org.opensearch.ad.transport.DeleteAnomalyDetectorRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.DeleteConfigRequest; import com.google.common.collect.ImmutableList; @@ -39,9 +36,6 @@ public class RestDeleteAnomalyDetectorAction extends BaseRestHandler { public static final String DELETE_ANOMALY_DETECTOR_ACTION = "delete_anomaly_detector"; - private static final Logger logger = LogManager.getLogger(RestDeleteAnomalyDetectorAction.class); - private final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); - public RestDeleteAnomalyDetectorAction() {} @Override @@ -56,7 +50,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } String detectorId = request.param(DETECTOR_ID); - DeleteAnomalyDetectorRequest deleteAnomalyDetectorRequest = new 
DeleteAnomalyDetectorRequest(detectorId); + DeleteConfigRequest deleteAnomalyDetectorRequest = new DeleteConfigRequest(detectorId); return channel -> client .execute(DeleteAnomalyDetectorAction.INSTANCE, deleteAnomalyDetectorRequest, new RestToXContentListener<>(channel)); } diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index 315ba0410..8ef0bb473 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -18,24 +18,20 @@ import java.io.IOException; import java.util.List; import java.util.Locale; -import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.GetAnomalyDetectorAction; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -66,7 +62,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean returnJob = request.paramAsBoolean("job", false); boolean returnTask = request.paramAsBoolean("task", false); boolean all = request.paramAsBoolean("_all", false); - GetAnomalyDetectorRequest getConfigRequest = new GetAnomalyDetectorRequest( + GetConfigRequest getConfigRequest = new GetConfigRequest( detectorId, RestActions.parseVersion(request), returnJob, @@ -74,7 +70,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli typesStr, rawPath, all, - buildEntity(request, detectorId) + RestHandlerUtils.buildEntity(request, detectorId) ); return channel -> client.execute(GetAnomalyDetectorAction.INSTANCE, getConfigRequest, new RestToXContentListener<>(channel)); @@ -137,35 +133,4 @@ public List replacedRoutes() { ) ); } - - private Entity buildEntity(RestRequest request, String detectorId) throws IOException { - if (Strings.isEmpty(detectorId)) { - throw new IllegalStateException(ADCommonMessages.AD_ID_MISSING_MSG); - } - - String entityName = request.param(ADCommonName.CATEGORICAL_FIELD); - String entityValue = request.param(CommonName.ENTITY_KEY); - - if (entityName != null && entityValue != null) { - // single-stream profile request: - // GET _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= - return Entity.createSingleAttributeEntity(entityName, entityValue); - } else if (request.hasContent()) { - /* HCAD profile request: - * GET _plugins/_anomaly_detection/detectors//_profile/init_progress - * { - * "entity": [{ - * "name": "clientip", - * "value": "13.24.0.0" - * }] - * } - */ - Optional entity = Entity.fromJsonObject(request.contentParser()); - if (entity.isPresent()) { - return entity.get(); - } - } - // not a valid profile request 
with correct entity information - return null; - } } diff --git a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java index 6231d8e11..66981d54c 100644 --- a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java @@ -94,7 +94,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli requestTimeout, maxSingleEntityDetectors, maxMultiEntityDetectors, - maxAnomalyFeatures + maxAnomalyFeatures, + maxCategoricalFields ); return channel -> client diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java index 6a1bfce58..a858d46aa 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java @@ -22,7 +22,7 @@ /** * This class consists of the REST handler to search AD tasks. */ -public class RestSearchADTasksAction extends AbstractSearchAction { +public class RestSearchADTasksAction extends AbstractADSearchAction { private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/tasks/_search"; private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/tasks/_search"; diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java index 214fa8b2c..a5c1551e7 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java @@ -22,7 +22,7 @@ /** * This class consists of the REST handler to search anomaly detectors. 
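The `buildEntity` method removed above now lives in `RestHandlerUtils.buildEntity`; it accepts the two profile-request shapes documented in the deleted comments. A hedged sketch of that resolution order (`EntityFromRequestSketch` is an illustrative class, not the util's actual code):

```java
import java.io.IOException;
import java.util.Optional;

import org.opensearch.ad.constant.ADCommonName;
import org.opensearch.rest.RestRequest;
import org.opensearch.timeseries.constant.CommonName;
import org.opensearch.timeseries.model.Entity;

// Sketch of the entity-resolution order the removed buildEntity implemented.
final class EntityFromRequestSketch {

    static Entity resolve(RestRequest request) throws IOException {
        String entityName = request.param(ADCommonName.CATEGORICAL_FIELD);
        String entityValue = request.param(CommonName.ENTITY_KEY);

        if (entityName != null && entityValue != null) {
            // single-stream style:
            // GET ..._profile/init_progress?category_field=<field>&entity=<value>
            return Entity.createSingleAttributeEntity(entityName, entityValue);
        }
        if (request.hasContent()) {
            // HCAD style body: {"entity": [{"name": "clientip", "value": "13.24.0.0"}]}
            Optional<Entity> entity = Entity.fromJsonObject(request.contentParser());
            return entity.orElse(null);
        }
        // not a valid profile request with entity information
        return null;
    }
}
```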
*/ -public class RestSearchAnomalyDetectorAction extends AbstractSearchAction { +public class RestSearchAnomalyDetectorAction extends AbstractADSearchAction { private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/_search"; private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/_search"; diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java index 1f2ade113..0b7f748c7 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java @@ -23,12 +23,12 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; -import org.opensearch.ad.transport.SearchAnomalyDetectorInfoRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.SearchConfigInfoRequest; import com.google.common.collect.ImmutableList; @@ -54,7 +54,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, org.opensearch String detectorName = request.param("name", null); String rawPath = request.rawPath(); - SearchAnomalyDetectorInfoRequest searchAnomalyDetectorInfoRequest = new SearchAnomalyDetectorInfoRequest(detectorName, rawPath); + SearchConfigInfoRequest searchAnomalyDetectorInfoRequest = new SearchConfigInfoRequest(detectorName, rawPath); return channel -> client .execute(SearchAnomalyDetectorInfoAction.INSTANCE, searchAnomalyDetectorInfoRequest, new RestToXContentListener<>(channel)); } diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java index 9db521595..b014ca753 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java @@ -35,7 +35,7 @@ /** * This class consists of the REST handler to search anomaly results. 
*/ -public class RestSearchAnomalyResultAction extends AbstractSearchAction { +public class RestSearchAnomalyResultAction extends AbstractADSearchAction { private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/results/_search"; private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/results/_search"; public static final String SEARCH_ANOMALY_RESULT_ACTION = "search_anomaly_result"; diff --git a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java index 65b936e98..ddceab44a 100644 --- a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java @@ -14,47 +14,36 @@ import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BASE_URI; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; -import java.util.Set; -import java.util.TreeSet; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.transport.ADStatsRequest; import org.opensearch.ad.transport.StatsAnomalyDetectorAction; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.common.Strings; -import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.rest.RestStatsAction; +import org.opensearch.timeseries.transport.StatsRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.ImmutableList; /** - * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from the anomaly detector plugin. + * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from AD. 
*/ -public class RestStatsAnomalyDetectorAction extends BaseRestHandler { +public class RestStatsAnomalyDetectorAction extends RestStatsAction { private static final String STATS_ANOMALY_DETECTOR_ACTION = "stats_anomaly_detector"; - private ADStats adStats; - private ClusterService clusterService; - private DiscoveryNodeFilterer nodeFilter; /** * Constructor * - * @param adStats ADStats object + * @param timeSeriesStats TimeSeriesStats object * @param nodeFilter util class to get eligible data nodes */ - public RestStatsAnomalyDetectorAction(ADStats adStats, DiscoveryNodeFilterer nodeFilter) { - this.adStats = adStats; - this.nodeFilter = nodeFilter; + public RestStatsAnomalyDetectorAction(ADStats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + super(timeSeriesStats, nodeFilter); } @Override @@ -67,64 +56,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (!ADEnabledSetting.isADEnabled()) { throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); } - ADStatsRequest adStatsRequest = getRequest(request); + StatsRequest adStatsRequest = getRequest(request); return channel -> client.execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest, new RestToXContentListener<>(channel)); } - /** - * Creates a ADStatsRequest from a RestRequest - * - * @param request RestRequest - * @return ADStatsRequest Request containing stats to be retrieved - */ - private ADStatsRequest getRequest(RestRequest request) { - // parse the nodes the user wants to query the stats for - String nodesIdsStr = request.param("nodeId"); - Set validStats = adStats.getStats().keySet(); - - ADStatsRequest adStatsRequest = null; - if (!Strings.isEmpty(nodesIdsStr)) { - String[] nodeIdsArr = nodesIdsStr.split(","); - adStatsRequest = new ADStatsRequest(nodeIdsArr); - } else { - DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); - adStatsRequest = new ADStatsRequest(dataNodes); - } - - adStatsRequest.timeout(request.param("timeout")); - - // parse the stats the user wants to see - HashSet statsSet = null; - String statsStr = request.param("stat"); - if (!Strings.isEmpty(statsStr)) { - statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); - } - - if (statsSet == null) { - adStatsRequest.addAll(validStats); // retrieve all stats if none are specified - } else if (statsSet.size() == 1 && statsSet.contains(ADStatsRequest.ALL_STATS_KEY)) { - adStatsRequest.addAll(validStats); - } else if (statsSet.contains(ADStatsRequest.ALL_STATS_KEY)) { - throw new IllegalArgumentException( - "Request " + request.path() + " contains " + ADStatsRequest.ALL_STATS_KEY + " and individual stats" - ); - } else { - Set invalidStats = new TreeSet<>(); - for (String stat : statsSet) { - if (validStats.contains(stat)) { - adStatsRequest.addStat(stat); - } else { - invalidStats.add(stat); - } - } - - if (!invalidStats.isEmpty()) { - throw new IllegalArgumentException(unrecognized(request, invalidStats, adStatsRequest.getStatsToBeRetrieved(), "stat")); - } - } - return adStatsRequest; - } - @Override public List routes() { return ImmutableList.of(); diff --git a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java index e728889f8..91d72dcf9 100644 --- a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java @@ -16,35 +16,25 @@ import static 
org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE;

import java.io.IOException;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
import java.util.List;
import java.util.Locale;
-import java.util.Set;
-import java.util.stream.Collectors;
-import org.apache.commons.lang3.StringUtils;
import org.opensearch.ad.constant.ADCommonMessages;
-import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.DetectorValidationIssue;
import org.opensearch.ad.settings.ADEnabledSetting;
import org.opensearch.ad.transport.ValidateAnomalyDetectorAction;
-import org.opensearch.ad.transport.ValidateAnomalyDetectorRequest;
-import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
-import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.rest.BaseRestHandler;
-import org.opensearch.rest.BytesRestResponse;
-import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
+import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
import org.opensearch.timeseries.common.exception.ValidationException;
-import org.opensearch.timeseries.model.ValidationAspect;
+import org.opensearch.timeseries.model.ConfigValidationIssue;
+import org.opensearch.timeseries.rest.RestValidateAction;
+import org.opensearch.timeseries.transport.ValidateConfigRequest;

import com.google.common.collect.ImmutableList;

@@ -54,14 +44,18 @@ public class RestValidateAnomalyDetectorAction extends AbstractAnomalyDetectorAction {
    private static final String VALIDATE_ANOMALY_DETECTOR_ACTION = "validate_anomaly_detector_action";

-    public static final Set ALL_VALIDATION_ASPECTS_STRS = Arrays
-        .asList(ValidationAspect.values())
-        .stream()
-        .map(aspect -> aspect.getName())
-        .collect(Collectors.toSet());
+    private RestValidateAction validateAction;

    public RestValidateAnomalyDetectorAction(Settings settings, ClusterService clusterService) {
        super(settings, clusterService);
+        this.validateAction = new RestValidateAction(
+            AnalysisType.AD,
+            maxSingleEntityDetectors,
+            maxMultiEntityDetectors,
+            maxAnomalyFeatures,
+            maxCategoricalFields,
+            requestTimeout
+        );
    }

    @Override
@@ -84,66 +78,35 @@ public List routes() {
        );
    }

-    protected void sendAnomalyDetectorValidationParseResponse(DetectorValidationIssue issue, RestChannel channel) throws IOException {
-        try {
-            BytesRestResponse restResponse = new BytesRestResponse(
-                RestStatus.OK,
-                new ValidateAnomalyDetectorResponse(issue).toXContent(channel.newBuilder())
-            );
-            channel.sendResponse(restResponse);
-        } catch (Exception e) {
-            channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage()));
-        }
-    }
-
-    private Boolean validationTypesAreAccepted(String validationType) {
-        Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(",")));
-        return (!Collections.disjoint(typesInRequest, ALL_VALIDATION_ASPECTS_STRS));
-    }
-
    @Override
    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
        if (!ADEnabledSetting.isADEnabled()) {
            throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG);
        }
+
        XContentParser parser = request.contentParser();
        ensureExpectedToken(XContentParser.Token.START_OBJECT,
parser.nextToken(), parser); + // we have to get the param from a subclass of BaseRestHandler. Otherwise, we cannot parse the type out of request params String typesStr = request.param(TYPE); - // if type param isn't blank and isn't a part of possible validation types throws exception - if (!StringUtils.isBlank(typesStr)) { - if (!validationTypesAreAccepted(typesStr)) { - throw new IllegalStateException(ADCommonMessages.NOT_EXISTENT_VALIDATION_TYPE); - } - } - return channel -> { - AnomalyDetector detector; try { - detector = AnomalyDetector.parse(parser); + ValidateConfigRequest validateAnomalyDetectorRequest = validateAction.prepareRequest(request, client, typesStr); + client + .execute(ValidateAnomalyDetectorAction.INSTANCE, validateAnomalyDetectorRequest, new RestToXContentListener<>(channel)); } catch (Exception ex) { if (ex instanceof ValidationException) { - ValidationException ADException = (ValidationException) ex; - DetectorValidationIssue issue = new DetectorValidationIssue( - ADException.getAspect(), - ADException.getType(), - ADException.getMessage() + ValidationException adException = (ValidationException) ex; + ConfigValidationIssue issue = new ConfigValidationIssue( + adException.getAspect(), + adException.getType(), + adException.getMessage() ); - sendAnomalyDetectorValidationParseResponse(issue, channel); - return; + validateAction.sendValidationParseResponse(issue, channel); } else { throw ex; } } - ValidateAnomalyDetectorRequest validateAnomalyDetectorRequest = new ValidateAnomalyDetectorRequest( - detector, - typesStr, - maxSingleEntityDetectors, - maxMultiEntityDetectors, - maxAnomalyFeatures, - requestTimeout - ); - client.execute(ValidateAnomalyDetectorAction.INSTANCE, validateAnomalyDetectorRequest, new RestToXContentListener<>(channel)); }; } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java new file mode 100644 index 000000000..fc0eb5cfe --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; + +import java.util.List; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADProfileAction; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.ad.transport.StopDetectorAction; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.timeseries.AnalysisType; +import 
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.rest.handler.IndexJobActionHandler;
+import org.opensearch.timeseries.transport.JobResponse;
+import org.opensearch.timeseries.transport.ResultRequest;
+import org.opensearch.transport.TransportService;
+
+public class ADIndexJobActionHandler extends
+    IndexJobActionHandler {
+
+    public ADIndexJobActionHandler(
+        Client client,
+        ADIndexManagement indexManagement,
+        NamedXContentRegistry xContentRegistry,
+        ADTaskManager adTaskManager,
+        ExecuteADResultResponseRecorder recorder,
+        NodeStateManager nodeStateManager,
+        Settings settings
+    ) {
+        super(
+            client,
+            indexManagement,
+            xContentRegistry,
+            adTaskManager,
+            recorder,
+            AnomalyResultAction.INSTANCE,
+            AnalysisType.AD,
+            DETECTION_STATE_INDEX,
+            StopDetectorAction.INSTANCE,
+            nodeStateManager,
+            settings,
+            AD_REQUEST_TIMEOUT
+        );
+    }
+
+    @Override
+    protected ResultRequest createResultRequest(String configID, long start, long end) {
+        return new AnomalyResultRequest(configID, start, end);
+    }
+
+    @Override
+    protected List<ADTaskType> getBatchConfigTaskTypes() {
+        return HISTORICAL_DETECTOR_TASK_TYPES;
+    }
+
+    /**
+     * Stop a config.
+     * For a realtime config, this disables the job.
+     * For a historical config, this cancels its latest task.
+     *
+     * @param configId config id
+     * @param historical whether to stop a historical analysis
+     * @param user user
+     * @param transportService transport service
+     * @param listener action listener
+     */
+    @Override
+    public void stopConfig(
+        String configId,
+        boolean historical,
+        User user,
+        TransportService transportService,
+        ActionListener<JobResponse> listener
+    ) {
+        // make sure the detector exists
+        nodeStateManager.getConfig(configId, AnalysisType.AD, (config) -> {
+            if (!config.isPresent()) {
+                listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND));
+                return;
+            }
+            if (historical) {
+                // stop historical analysis
+                taskManager
+                    .getAndExecuteOnLatestConfigLevelTask(
+                        configId,
+                        getBatchConfigTaskTypes(),
+                        (task) -> taskManager.stopHistoricalAnalysis(configId, task, user, listener),
+                        transportService,
+                        true, // reset task state when stopping the config
+                        listener
+                    );
+            } else {
+                // stop the realtime detector job
+                stopJob(configId, transportService, listener);
+            }
+        }, listener);
+    }
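+
+    // A sketch of the expected call paths (assuming the base class wiring): the stop
+    // transport action invokes stopConfig(detectorId, false, user, transportService, listener)
+    // to disable the realtime job, and stopConfig(detectorId, true, ...) to cancel the
+    // latest historical task via the task manager.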
+
+}
diff --git a/src/main/java/org/opensearch/ad/rest/handler/ADModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ADModelValidationActionHandler.java
new file mode 100644
index 000000000..78a1dfbe9
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/rest/handler/ADModelValidationActionHandler.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad.rest.handler;
+
+import java.time.Clock;
+
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.commons.authuser.User;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.feature.SearchFeatureDao;
+import org.opensearch.timeseries.rest.handler.ModelValidationActionHandler;
+import org.opensearch.timeseries.transport.ValidateConfigResponse;
+import org.opensearch.timeseries.util.SecurityClientUtil;
+
+public class ADModelValidationActionHandler extends ModelValidationActionHandler {
+
+    public ADModelValidationActionHandler(
+        ClusterService clusterService,
+        Client client,
+        SecurityClientUtil clientUtil,
+        ActionListener<ValidateConfigResponse> listener,
+        AnomalyDetector config,
+        TimeValue requestTimeout,
+        NamedXContentRegistry xContentRegistry,
+        SearchFeatureDao searchFeatureDao,
+        String validationType,
+        Clock clock,
+        Settings settings,
+        User user
+    ) {
+        super(
+            clusterService,
+            client,
+            clientUtil,
+            listener,
+            config,
+            requestTimeout,
+            xContentRegistry,
+            searchFeatureDao,
+            validationType,
+            clock,
+            settings,
+            user,
+            AnalysisType.AD
+        );
+    }
+
+}
diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java
index 614d47bee..63d1d0e26 100644
--- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java
+++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java
@@ -12,80 +12,46 @@
 package org.opensearch.ad.rest.handler;

 import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES;
-import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;
-import static org.opensearch.timeseries.util.ParseUtils.listEqualsWithoutConsideringOrder;
-import static org.opensearch.timeseries.util.ParseUtils.parseAggregators;
-import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE;
-import static org.opensearch.timeseries.util.RestHandlerUtils.isExceptionCausedByInvalidQuery;

 import java.io.IOException;
 import java.time.Clock;
 import java.time.Instant;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.List;
 import java.util.Locale;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
 import java.util.Set;
-import java.util.stream.Collectors;

-import org.apache.commons.lang.StringUtils;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.opensearch.OpenSearchStatusException;
-import org.opensearch.action.admin.indices.create.CreateIndexResponse;
-import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction;
-import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest;
-import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse;
-import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.get.GetResponse;
-import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
-import org.opensearch.action.search.SearchRequest;
-import org.opensearch.action.search.SearchResponse;
-import org.opensearch.action.support.IndicesOptions;
 import org.opensearch.action.support.WriteRequest;
-import org.opensearch.action.support.replication.ReplicationResponse;
+import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction;
-import org.opensearch.ad.settings.ADNumericSetting;
+import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.task.ADTaskManager;
 import 
org.opensearch.ad.transport.IndexAnomalyDetectorResponse; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; -import org.opensearch.search.aggregations.AggregatorFactories; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.model.Feature; -import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; -import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -116,48 +82,17 @@ * instantiate the ModelValidationActionHandler class and run the non-blocker validation logic

* */ -public abstract class AbstractAnomalyDetectorActionHandler { - public static final String EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG = "Can't create more than %d multi-entity anomaly detectors."; - public static final String EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG = - "Can't create more than %d single-entity anomaly detectors."; - public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document is found in the indices: "; - public static final String ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG = "We can have only one categorical field."; - public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; - public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; - public static final String DUPLICATE_DETECTOR_MSG = "Cannot create anomaly detector with name [%s] as it's already used by detector %s"; - public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; - public static final Integer MAX_DETECTOR_NAME_SIZE = 64; - private static final Set DEFAULT_VALIDATION_ASPECTS = Sets.newHashSet(ValidationAspect.DETECTOR); - - public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + MAX_DETECTOR_NAME_SIZE + " characters."; - - protected final ADIndexManagement anomalyDetectionIndices; - protected final String detectorId; - protected final Long seqNo; - protected final Long primaryTerm; - protected final WriteRequest.RefreshPolicy refreshPolicy; - protected final AnomalyDetector anomalyDetector; - protected final ClusterService clusterService; - +public abstract class AbstractAnomalyDetectorActionHandler extends + AbstractTimeSeriesActionHandler { protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); - protected final TimeValue requestTimeout; - protected final Integer maxSingleEntityAnomalyDetectors; - protected final Integer maxMultiEntityAnomalyDetectors; - protected final Integer maxAnomalyFeatures; - protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); - protected final RestRequest.Method method; - protected final Client client; - protected final SecurityClientUtil clientUtil; - protected final TransportService transportService; - protected final NamedXContentRegistry xContentRegistry; - protected final ActionListener listener; - protected final User user; - protected final ADTaskManager adTaskManager; - protected final SearchFeatureDao searchFeatureDao; - protected final boolean isDryRun; - protected final Clock clock; - protected final String validationType; - protected final Settings settings; + + public static final String EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG = "Can't create more than %d HC anomaly detectors."; + public static final String EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG = + "Can't create more than %d single-stream anomaly detectors."; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document is found in the indices: "; + public static final String DUPLICATE_DETECTOR_MSG = + "Cannot create anomaly detector with name [%s] as it's already used by another detector"; + public static final String VALIDATION_FEATURE_FAILURE = "Validation failed for feature(s) of detector %s"; /** * Constructor function. 
@@ -166,7 +101,6 @@ public abstract class AbstractAnomalyDetectorActionHandler listener, ADIndexManagement anomalyDetectionIndices, String detectorId, Long seqNo, Long primaryTerm, WriteRequest.RefreshPolicy refreshPolicy, - AnomalyDetector anomalyDetector, + Config anomalyDetector, TimeValue requestTimeout, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, + Integer maxSingleStreamAnomalyDetectors, + Integer maxHCAnomalyDetectors, + Integer maxFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, @@ -213,746 +148,132 @@ public AbstractAnomalyDetectorActionHandler( Clock clock, Settings settings ) { - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; - this.transportService = transportService; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.listener = listener; - this.detectorId = detectorId; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.refreshPolicy = refreshPolicy; - this.anomalyDetector = anomalyDetector; - this.requestTimeout = requestTimeout; - this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; - this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; - this.maxAnomalyFeatures = maxAnomalyFeatures; - this.method = method; - this.xContentRegistry = xContentRegistry; - this.user = user; - this.adTaskManager = adTaskManager; - this.searchFeatureDao = searchFeatureDao; - this.validationType = validationType; - this.isDryRun = isDryRun; - this.clock = clock; - this.settings = settings; - } - - /** - * Start function to process create/update/validate anomaly detector request. - * If detector is not using custom result index, check if anomaly detector - * index exist first, if not, will create first. Otherwise, check if custom - * result index exists or not. If exists, will check if index mapping matches - * AD result index mapping and if user has correct permission to write index. - * If doesn't exist, will create custom result index with AD result index - * mapping. 
- */ - public void start() { - String resultIndex = anomalyDetector.getCustomResultIndex(); - // use default detector result index which is system index - if (resultIndex == null) { - createOrUpdateDetector(); - return; - } - - if (this.isDryRun) { - if (anomalyDetectionIndices.doesIndexExist(resultIndex)) { - anomalyDetectionIndices - .validateCustomResultIndexAndExecute( - resultIndex, - () -> createOrUpdateDetector(), - ActionListener.wrap(r -> createOrUpdateDetector(), ex -> { - logger.error(ex); - listener - .onFailure( - new ValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX, ValidationAspect.DETECTOR) - ); - return; - }) - ); - return; - } else { - createOrUpdateDetector(); - return; - } - } - // use custom result index if not validating and resultIndex not null - anomalyDetectionIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateDetector(), listener); - } - - // if isDryRun is true then this method is being executed through Validation API meaning actual - // index won't be created, only validation checks will be executed throughout the class - private void createOrUpdateDetector() { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (!anomalyDetectionIndices.doesConfigIndexExist() && !this.isDryRun) { - logger.info("AnomalyDetector Indices do not exist"); - anomalyDetectionIndices - .initConfigIndex( - ActionListener - .wrap(response -> onCreateMappingsResponse(response, false), exception -> listener.onFailure(exception)) - ); - } else { - logger.info("AnomalyDetector Indices do exist, calling prepareAnomalyDetectorIndexing"); - logger.info("DryRun variable " + this.isDryRun); - validateDetectorName(this.isDryRun); - } - } catch (Exception e) { - logger.error("Failed to create or update detector " + detectorId, e); - listener.onFailure(e); - } - } - - // These validation checks are executed here and not in AnomalyDetector.parse() - // in order to not break any past detectors that were made with invalid names - // because it was never check on the backend in the past - protected void validateDetectorName(boolean indexingDryRun) { - if (!anomalyDetector.getName().matches(NAME_REGEX)) { - listener.onFailure(new ValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - return; - - } - if (anomalyDetector.getName().length() > MAX_DETECTOR_NAME_SIZE) { - listener.onFailure(new ValidationException(INVALID_NAME_SIZE, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - return; - } - validateTimeField(indexingDryRun); - } - - protected void validateTimeField(boolean indexingDryRun) { - String givenTimeField = anomalyDetector.getTimeField(); - GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); - getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(givenTimeField); - getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); - - // comments explaining fieldMappingResponse parsing can be found inside following method: - // AbstractAnomalyDetectorActionHandler.validateCategoricalField(String, boolean) - ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { - boolean foundField = false; - Map> mappingsByIndex = getMappingsResponse.mappings(); - - for (Map mappingsByField : mappingsByIndex.values()) { - for (Map.Entry field2Metadata : mappingsByField.entrySet()) { - - GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); - if 
(fieldMetadata != null) { - // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field - Map fieldMap = fieldMetadata.sourceAsMap(); - if (fieldMap != null) { - for (Object type : fieldMap.values()) { - if (type instanceof Map) { - foundField = true; - Map metadataMap = (Map) type; - String typeName = (String) metadataMap.get(CommonName.TYPE); - if (!typeName.equals(CommonName.DATE_TYPE)) { - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField), - ValidationIssueType.TIMEFIELD_FIELD, - ValidationAspect.DETECTOR - ) - ); - return; - } - } - } - } - } - } - } - if (!foundField) { - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField), - ValidationIssueType.TIMEFIELD_FIELD, - ValidationAspect.DETECTOR - ) - ); - return; - } - prepareAnomalyDetectorIndexing(indexingDryRun); - }, error -> { - String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices()); - logger.error(message, error); - listener.onFailure(new IllegalArgumentException(message)); - }); - clientUtil - .executeWithInjectedSecurity( - GetFieldMappingsAction.INSTANCE, - getMappingsRequest, - user, - client, - AnalysisType.AD, - mappingsListener - ); - } - - /** - * Prepare for indexing a new anomaly detector. - * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false - */ - protected void prepareAnomalyDetectorIndexing(boolean indexingDryRun) { - if (method == RestRequest.Method.PUT) { - handler - .getDetectorJob( - clusterService, - client, - detectorId, - listener, - () -> updateAnomalyDetector(detectorId, indexingDryRun), - xContentRegistry - ); - } else { - createAnomalyDetector(indexingDryRun); - } - } - - protected void updateAnomalyDetector(String detectorId, boolean indexingDryRun) { - GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, detectorId); - client - .get( - request, - ActionListener - .wrap( - response -> onGetAnomalyDetectorResponse(response, indexingDryRun, detectorId), - exception -> listener.onFailure(exception) - ) - ); - } - - private void onGetAnomalyDetectorResponse(GetResponse response, boolean indexingDryRun, String detectorId) { - if (!response.isExists()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); - // If detector category field changed, frontend may not be able to render AD result for different detector types correctly. - // For example, if detector changed from HC to single entity detector, AD result page may show multiple anomaly - // result points on the same time point if there are multiple entities have anomaly results. - // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type - // in top N entities list. That's confusing. - // So we decide to block updating detector category field. 
- if (!listEqualsWithoutConsideringOrder(existingDetector.getCategoryFields(), anomalyDetector.getCategoryFields())) { - listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST)); - return; - } - if (!Objects.equals(existingDetector.getCustomResultIndex(), anomalyDetector.getCustomResultIndex())) { - listener - .onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST)); - return; - } - - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, (adTask) -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - // can't update detector if there is AD task running - listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); - } else { - validateExistingDetector(existingDetector, indexingDryRun); - } - }, transportService, true, listener); - } catch (IOException e) { - String message = "Failed to parse anomaly detector " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - - } - - protected void validateExistingDetector(AnomalyDetector existingDetector, boolean indexingDryRun) { - if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) { - validateAgainstExistingMultiEntityAnomalyDetector(detectorId, indexingDryRun); - } else { - validateCategoricalField(detectorId, indexingDryRun); - } - } - - protected boolean hasCategoryField(AnomalyDetector detector) { - return detector.getCategoryFields() != null && !detector.getCategoryFields().isEmpty(); - } - - protected void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId, boolean indexingDryRun) { - if (anomalyDetectionIndices.doesConfigIndexExist()) { - QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); - client - .search( - searchRequest, - ActionListener - .wrap( - response -> onSearchMultiEntityAdResponse(response, detectorId, indexingDryRun), - exception -> listener.onFailure(exception) - ) - ); - } else { - validateCategoricalField(detectorId, indexingDryRun); - } - - } - - protected void createAnomalyDetector(boolean indexingDryRun) { - try { - List categoricalFields = anomalyDetector.getCategoryFields(); - if (categoricalFields != null && categoricalFields.size() > 0) { - validateAgainstExistingMultiEntityAnomalyDetector(null, indexingDryRun); - } else { - if (anomalyDetectionIndices.doesConfigIndexExist()) { - QueryBuilder query = QueryBuilders.matchAllQuery(); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); - - client - .search( - searchRequest, - ActionListener - .wrap( - response -> onSearchSingleEntityAdResponse(response, indexingDryRun), - exception -> listener.onFailure(exception) - ) - ); - } else { - searchAdInputIndices(null, indexingDryRun); - } - - } - } catch (Exception e) { - listener.onFailure(e); - } - } - - protected void onSearchSingleEntityAdResponse(SearchResponse response, boolean indexingDryRun) throws IOException { - if 
(response.getHits().getTotalHits().value >= maxSingleEntityAnomalyDetectors) { - String errorMsgSingleEntity = String - .format(Locale.ROOT, EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors); - logger.error(errorMsgSingleEntity); - if (indexingDryRun) { - listener - .onFailure( - new ValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR) - ); - return; - } - listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity)); - } else { - searchAdInputIndices(null, indexingDryRun); - } - } - - protected void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { - if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) { - String errorMsg = String.format(Locale.ROOT, EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); - logger.error(errorMsg); - if (indexingDryRun) { - listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR)); - return; - } - listener.onFailure(new IllegalArgumentException(errorMsg)); - } else { - validateCategoricalField(detectorId, indexingDryRun); - } - } - - @SuppressWarnings("unchecked") - protected void validateCategoricalField(String detectorId, boolean indexingDryRun) { - List categoryField = anomalyDetector.getCategoryFields(); - - if (categoryField == null) { - searchAdInputIndices(detectorId, indexingDryRun); - return; - } - - // we only support a certain number of categorical field - // If there is more fields than required, AnomalyDetector's constructor - // throws ADValidationException before reaching this line - int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); - if (categoryField.size() > maxCategoryFields) { - listener - .onFailure( - new ValidationException( - CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), - ValidationIssueType.CATEGORY, - ValidationAspect.DETECTOR - ) - ); - return; - } - - String categoryField0 = categoryField.get(0); - - GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); - getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); - getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); - - ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { - // example getMappingsResponse: - // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', - // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} - // for nested field, it would be - // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', - // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} - boolean foundField = false; - - // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata - Map> mappingsByIndex = getMappingsResponse.mappings(); - - for (Map mappingsByField : mappingsByIndex.values()) { - for (Map.Entry field2Metadata : mappingsByField.entrySet()) { - // example output: - // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', - // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} - - // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata - - GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = 
field2Metadata.getValue(); - - if (fieldMetadata != null) { - // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field - Map fieldMap = fieldMetadata.sourceAsMap(); - if (fieldMap != null) { - for (Object type : fieldMap.values()) { - if (type != null && type instanceof Map) { - foundField = true; - Map metadataMap = (Map) type; - String typeName = (String) metadataMap.get(CommonName.TYPE); - if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { - listener - .onFailure( - new ValidationException( - CATEGORICAL_FIELD_TYPE_ERR_MSG, - ValidationIssueType.CATEGORY, - ValidationAspect.DETECTOR - ) - ); - return; - } - } - } - } - - } - } - } - - if (foundField == false) { - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, CATEGORY_NOT_FOUND_ERR_MSG, categoryField0), - ValidationIssueType.CATEGORY, - ValidationAspect.DETECTOR - ) - ); - return; - } - - searchAdInputIndices(detectorId, indexingDryRun); - }, error -> { - String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices()); - logger.error(message, error); - listener.onFailure(new IllegalArgumentException(message)); - }); + super( + anomalyDetector, + anomalyDetectionIndices, + isDryRun, + client, + detectorId, + clientUtil, + user, + method, + clusterService, + xContentRegistry, + transportService, + requestTimeout, + refreshPolicy, + seqNo, + primaryTerm, + validationType, + searchFeatureDao, + maxFeatures, + maxCategoricalFields, + AnalysisType.AD, + adTaskManager, + HISTORICAL_DETECTOR_TASK_TYPES, + false, + maxSingleStreamAnomalyDetectors, + maxHCAnomalyDetectors, + clock, + settings + ); - clientUtil - .executeWithInjectedSecurity( - GetFieldMappingsAction.INSTANCE, - getMappingsRequest, - user, - client, - AnalysisType.AD, - mappingsListener - ); } - protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(QueryBuilders.matchAllQuery()) - .size(0) - .timeout(requestTimeout); - - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - - ActionListener searchResponseListener = ActionListener - .wrap( - searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), - exception -> listener.onFailure(exception) - ); - - clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, AnalysisType.AD, searchResponseListener); + @Override + protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) { + return new ValidationException(msg, type, ValidationAspect.DETECTOR); } - protected void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { - if (response.getHits().getTotalHits().value == 0) { - String errorMsg = NO_DOCS_IN_USER_INDEX_MSG + Arrays.toString(anomalyDetector.getIndices().toArray(new String[0])); - logger.error(errorMsg); - if (indexingDryRun) { - listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.INDICES, ValidationAspect.DETECTOR)); - return; - } - listener.onFailure(new IllegalArgumentException(errorMsg)); - } else { - validateAnomalyDetectorFeatures(detectorId, indexingDryRun); - } + @Override + protected AnomalyDetector parse(XContentParser parser, GetResponse response) throws IOException { + return 
AnomalyDetector.parse(parser, response.getId(), response.getVersion());
     }

-    protected void checkADNameExists(String detectorId, boolean indexingDryRun) throws IOException {
-        if (anomalyDetectionIndices.doesConfigIndexExist()) {
-            BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-            // src/main/resources/mappings/anomaly-detectors.json#L14
-            boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", anomalyDetector.getName()));
-            if (StringUtils.isNotBlank(detectorId)) {
-                boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, detectorId));
-            }
-            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout);
-            SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder);
-            client
-                .search(
-                    searchRequest,
-                    ActionListener
-                        .wrap(
-                            searchResponse -> onSearchADNameResponse(searchResponse, detectorId, anomalyDetector.getName(), indexingDryRun),
-                            exception -> listener.onFailure(exception)
-                        )
-                );
-        } else {
-            tryIndexingAnomalyDetector(indexingDryRun);
-        }
-
+    @Override
+    protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) {
+        return String.format(Locale.ROOT, EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleStreamConfigs);
     }

-    protected void onSearchADNameResponse(SearchResponse response, String detectorId, String name, boolean indexingDryRun)
-        throws IOException {
-        if (response.getHits().getTotalHits().value > 0) {
-            String errorMsg = String
-                .format(
-                    Locale.ROOT,
-                    DUPLICATE_DETECTOR_MSG,
-                    name,
-                    Arrays.stream(response.getHits().getHits()).map(hit -> hit.getId()).collect(Collectors.toList())
-                );
-            logger.warn(errorMsg);
-            listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.NAME, ValidationAspect.DETECTOR));
-        } else {
-            tryIndexingAnomalyDetector(indexingDryRun);
-        }
+    @Override
+    protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) {
+        return String.format(Locale.ROOT, EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxHCConfigs);
     }

-    protected void tryIndexingAnomalyDetector(boolean indexingDryRun) throws IOException {
-        if (!indexingDryRun) {
-            indexAnomalyDetector(detectorId);
-        } else {
-            finishDetectorValidationOrContinueToModelValidation();
-        }
+    @Override
+    protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) {
+        return NO_DOCS_IN_USER_INDEX_MSG + suppliedIndices;
     }

-    protected Set<ValidationAspect> getValidationTypes(String validationType) {
-        if (StringUtils.isBlank(validationType)) {
-            return DEFAULT_VALIDATION_ASPECTS;
-        } else {
-            Set<String> typesInRequest = new HashSet<>(Arrays.asList(validationType.split(",")));
-            return ValidationAspect
-                .getNames(Sets.intersection(RestValidateAnomalyDetectorAction.ALL_VALIDATION_ASPECTS_STRS, typesInRequest));
-        }
+    @Override
+    protected String getDuplicateConfigErrorMsg(String name) {
+        return String.format(Locale.ROOT, DUPLICATE_DETECTOR_MSG, name);
     }

-    protected void finishDetectorValidationOrContinueToModelValidation() {
-        logger.info("Skipping indexing detector. 
No blocking issue found so far."); - if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) { - listener.onResponse(null); - } else { - ModelValidationActionHandler modelValidationActionHandler = new ModelValidationActionHandler( - clusterService, - client, - clientUtil, - (ActionListener) listener, - anomalyDetector, - requestTimeout, - xContentRegistry, - searchFeatureDao, - validationType, - clock, - settings, - user - ); - modelValidationActionHandler.checkIfMultiEntityDetector(); - } - } - - @SuppressWarnings("unchecked") - protected void indexAnomalyDetector(String detectorId) throws IOException { - AnomalyDetector detector = new AnomalyDetector( - anomalyDetector.getId(), - anomalyDetector.getVersion(), - anomalyDetector.getName(), - anomalyDetector.getDescription(), - anomalyDetector.getTimeField(), - anomalyDetector.getIndices(), - anomalyDetector.getFeatureAttributes(), - anomalyDetector.getFilterQuery(), - anomalyDetector.getInterval(), - anomalyDetector.getWindowDelay(), - anomalyDetector.getShingleSize(), - anomalyDetector.getUiMetadata(), - anomalyDetector.getSchemaVersion(), + @Override + protected AnomalyDetector copyConfig(User user, Config config) { + return new AnomalyDetector( + config.getId(), + config.getVersion(), + config.getName(), + config.getDescription(), + config.getTimeField(), + config.getIndices(), + config.getFeatureAttributes(), + config.getFilterQuery(), + config.getInterval(), + config.getWindowDelay(), + config.getShingleSize(), + config.getUiMetadata(), + config.getSchemaVersion(), Instant.now(), - anomalyDetector.getCategoryFields(), + config.getCategoryFields(), user, - anomalyDetector.getCustomResultIndex(), - anomalyDetector.getImputationOption() + config.getCustomResultIndex(), + config.getImputationOption(), + config.getTransformDecay() ); - IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) - .setRefreshPolicy(refreshPolicy) - .source(detector.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .timeout(requestTimeout); - if (StringUtils.isNotBlank(detectorId)) { - indexRequest.id(detectorId); - } - - client.index(indexRequest, new ActionListener() { - @Override - public void onResponse(IndexResponse indexResponse) { - String errorMsg = checkShardsFailure(indexResponse); - if (errorMsg != null) { - listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status())); - return; - } - listener - .onResponse( - (T) new IndexAnomalyDetectorResponse( - indexResponse.getId(), - indexResponse.getVersion(), - indexResponse.getSeqNo(), - indexResponse.getPrimaryTerm(), - detector, - RestStatus.CREATED - ) - ); - } - - @Override - public void onFailure(Exception e) { - logger.warn("Failed to update detector", e); - if (e.getMessage() != null && e.getMessage().contains("version conflict")) { - listener - .onFailure( - new IllegalArgumentException("There was a problem updating the historical detector:[" + detectorId + "]") - ); - } else { - listener.onFailure(e); - } - } - }); } - protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun) throws IOException { - if (response.isAcknowledged()) { - logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); - prepareAnomalyDetectorIndexing(indexingDryRun); - } else { - logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); - listener - .onFailure( - new OpenSearchStatusException( - "Created " + CommonName.CONFIG_INDEX 
+ "with mappings call not acknowledged.", - RestStatus.INTERNAL_SERVER_ERROR - ) - ); - } + @SuppressWarnings("unchecked") + @Override + protected T createIndexConfigResponse(IndexResponse indexResponse, Config config) { + return (T) new IndexAnomalyDetectorResponse( + indexResponse.getId(), + indexResponse.getVersion(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + (AnomalyDetector) config, + RestStatus.CREATED + ); } - protected String checkShardsFailure(IndexResponse response) { - StringBuilder failureReasons = new StringBuilder(); - if (response.getShardInfo().getFailed() > 0) { - for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { - failureReasons.append(failure); - } - return failureReasons.toString(); - } - return null; + @Override + protected Set getDefaultValidationType() { + return Sets.newHashSet(ValidationAspect.DETECTOR); } - /** - * Validate config/syntax, and runtime error of detector features - * @param detectorId detector id - * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector - * @throws IOException when fail to parse feature aggregation - */ - // TODO: move this method to util class so that it can be re-usable for more use cases - // https://github.com/opensearch-project/anomaly-detection/issues/39 - protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexingDryRun) throws IOException { - if (anomalyDetector != null - && (anomalyDetector.getFeatureAttributes() == null || anomalyDetector.getFeatureAttributes().isEmpty())) { - checkADNameExists(detectorId, indexingDryRun); - return; - } - // checking configuration/syntax error of detector features - String error = RestHandlerUtils.checkFeaturesSyntax(anomalyDetector, maxAnomalyFeatures); - if (StringUtils.isNotBlank(error)) { - if (indexingDryRun) { - listener.onFailure(new ValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.DETECTOR)); - return; - } - listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); - return; - } - // checking runtime error from feature query - ActionListener>> validateFeatureQueriesListener = ActionListener.wrap(response -> { - checkADNameExists(detectorId, indexingDryRun); - }, exception -> { - listener - .onFailure( - new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.DETECTOR) - ); - }); - MultiResponsesDelegateActionListener>> multiFeatureQueriesResponseListener = - new MultiResponsesDelegateActionListener>>( - validateFeatureQueriesListener, - anomalyDetector.getFeatureAttributes().size(), - String.format(Locale.ROOT, "Validation failed for feature(s) of detector %s", anomalyDetector.getName()), - false - ); + @Override + protected String getFeatureErrorMsg(String name) { + return String.format(Locale.ROOT, VALIDATION_FEATURE_FAILURE, name); + } - for (Feature feature : anomalyDetector.getFeatureAttributes()) { - SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery()); - AggregatorFactories.Builder internalAgg = parseAggregators( - feature.getAggregation().toString(), - xContentRegistry, - feature.getId() - ); - ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); - SearchRequest searchRequest = new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(ssb); - ActionListener searchResponseListener = ActionListener.wrap(response -> { - Optional aggFeatureResult = 
+    @Override
+    protected void validateModel(ActionListener<T> listener) {
+        ADModelValidationActionHandler modelValidationActionHandler = new ADModelValidationActionHandler(
+            clusterService,
+            client,
+            clientUtil,
+            (ActionListener<ValidateConfigResponse>) listener,
+            (AnomalyDetector) config,
+            requestTimeout,
+            xContentRegistry,
+            searchFeatureDao,
+            validationType,
+            clock,
+            settings,
+            user
+        );
+        modelValidationActionHandler.start();
+    }
 }
diff --git a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java
deleted file mode 100644
index 28e68d0fb..000000000
--- a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.rest.handler;
-
-import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
-
-import java.io.IOException;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.opensearch.OpenSearchStatusException;
-import org.opensearch.action.get.GetRequest;
-import org.opensearch.action.get.GetResponse;
-import org.opensearch.client.Client;
-import org.opensearch.cluster.service.ClusterService;
-import org.opensearch.core.action.ActionListener;
-import org.opensearch.core.rest.RestStatus;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
-import org.opensearch.core.xcontent.XContentParser;
-import org.opensearch.timeseries.constant.CommonName;
-import org.opensearch.timeseries.function.ExecutorFunction;
-import org.opensearch.timeseries.model.Job;
-import org.opensearch.timeseries.util.RestHandlerUtils;
-
-/**
- * Common handler to process AD request.
- */
-public class AnomalyDetectorActionHandler {
-
-    private final Logger logger = LogManager.getLogger(AnomalyDetectorActionHandler.class);
-
-    /**
-     * Get detector job for update/delete AD job.
-     * If AD job exist, will return error message; otherwise, execute function. 
- * - * @param clusterService ES cluster service - * @param client ES node client - * @param detectorId detector identifier - * @param listener Listener to send response - * @param function AD function - * @param xContentRegistry Registry which is used for XContentParser - */ - public void getDetectorJob( - ClusterService clusterService, - Client client, - String detectorId, - ActionListener listener, - ExecutorFunction function, - NamedXContentRegistry xContentRegistry - ) { - if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { - GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - client - .get( - request, - ActionListener - .wrap(response -> onGetAdJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { - logger.error("Fail to get anomaly detector job: " + detectorId, exception); - listener.onFailure(exception); - }) - ); - } else { - function.execute(); - } - } - - private void onGetAdJobResponseForWrite( - GetResponse response, - ActionListener listener, - ExecutorFunction function, - NamedXContentRegistry xContentRegistry - ) { - if (response.isExists()) { - String adJobId = response.getId(); - if (adJobId != null) { - // check if AD job is running on the detector, if yes, we can't delete the detector - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job adJob = Job.parse(parser); - if (adJob.isEnabled()) { - listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); - return; - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + adJobId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.BAD_REQUEST)); - } - } - } - function.execute(); - } -} diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java index bed6a7998..a600d8750 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java @@ -21,7 +21,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; import org.opensearch.timeseries.feature.SearchFeatureDao; @@ -42,7 +41,6 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * @param client ES node client that executes actions on the local node * @param clientUtil AD client util * @param transportService ES transport service - * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager * @param detectorId detector identifier * @param seqNo sequence number of last modification @@ -50,9 +48,10 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc * @param refreshPolicy refresh policy * @param anomalyDetector anomaly detector instance * @param requestTimeout request time out configuration - * @param maxSingleEntityAnomalyDetectors max 
single-entity anomaly detectors allowed - * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed - * @param maxAnomalyFeatures max features allowed per detector + * @param maxSingleStreamDetectors max single-stream anomaly detectors allowed + * @param maxHCDetectors max HC detectors allowed + * @param maxFeatures max features allowed per detector + * @param maxCategoricalFields max number of categorical fields * @param method Rest Method type * @param xContentRegistry Registry which is used for XContentParser * @param user User context @@ -65,7 +64,6 @@ public IndexAnomalyDetectorActionHandler( Client client, SecurityClientUtil clientUtil, TransportService transportService, - ActionListener listener, ADIndexManagement anomalyDetectionIndices, String detectorId, Long seqNo, @@ -73,9 +71,10 @@ public IndexAnomalyDetectorActionHandler( WriteRequest.RefreshPolicy refreshPolicy, AnomalyDetector anomalyDetector, TimeValue requestTimeout, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, + Integer maxSingleStreamDetectors, + Integer maxHCDetectors, + Integer maxFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, @@ -88,7 +87,6 @@ public IndexAnomalyDetectorActionHandler( client, clientUtil, transportService, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -96,9 +94,10 @@ public IndexAnomalyDetectorActionHandler( refreshPolicy, anomalyDetector, requestTimeout, - maxSingleEntityAnomalyDetectors, - maxMultiEntityAnomalyDetectors, - maxAnomalyFeatures, + maxSingleStreamDetectors, + maxHCDetectors, + maxFeatures, + maxCategoricalFields, method, xContentRegistry, user, @@ -110,12 +109,4 @@ public IndexAnomalyDetectorActionHandler( settings ); } - - /** - * Start function to process create/update anomaly detector request. - */ - @Override - public void start() { - super.start(); - } } diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java deleted file mode 100644 index fd59759fc..000000000 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java +++ /dev/null @@ -1,413 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -package org.opensearch.ad.rest.handler; - -import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure; -import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; - -import java.io.IOException; -import java.time.Duration; -import java.time.Instant; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.ExecuteADResultResponseRecorder; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; -import org.opensearch.ad.transport.AnomalyResultAction; -import org.opensearch.ad.transport.AnomalyResultRequest; -import org.opensearch.ad.transport.StopDetectorAction; -import org.opensearch.ad.transport.StopDetectorRequest; -import org.opensearch.ad.transport.StopDetectorResponse; -import org.opensearch.client.Client; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; -import org.opensearch.jobscheduler.spi.schedule.Schedule; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.model.TaskState; -import org.opensearch.timeseries.util.RestHandlerUtils; -import org.opensearch.transport.TransportService; - -import com.google.common.base.Throwables; - -/** - * Anomaly detector job REST action handler to process POST/PUT request. - */ -public class IndexAnomalyDetectorJobActionHandler { - - private final ADIndexManagement anomalyDetectionIndices; - private final String detectorId; - private final Long seqNo; - private final Long primaryTerm; - private final Client client; - private final NamedXContentRegistry xContentRegistry; - private final TransportService transportService; - private final ADTaskManager adTaskManager; - - private final Logger logger = LogManager.getLogger(IndexAnomalyDetectorJobActionHandler.class); - private final TimeValue requestTimeout; - private final ExecuteADResultResponseRecorder recorder; - - /** - * Constructor function. 
- * - * @param client ES node client that executes actions on the local node - * @param anomalyDetectionIndices anomaly detector index manager - * @param detectorId detector identifier - * @param seqNo sequence number of last modification - * @param primaryTerm primary term of last modification - * @param requestTimeout request time out configuration - * @param xContentRegistry Registry which is used for XContentParser - * @param transportService transport service - * @param adTaskManager AD task manager - * @param recorder Utility to record AnomalyResultAction execution result - */ - public IndexAnomalyDetectorJobActionHandler( - Client client, - ADIndexManagement anomalyDetectionIndices, - String detectorId, - Long seqNo, - Long primaryTerm, - TimeValue requestTimeout, - NamedXContentRegistry xContentRegistry, - TransportService transportService, - ADTaskManager adTaskManager, - ExecuteADResultResponseRecorder recorder - ) { - this.client = client; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.detectorId = detectorId; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.requestTimeout = requestTimeout; - this.xContentRegistry = xContentRegistry; - this.transportService = transportService; - this.adTaskManager = adTaskManager; - this.recorder = recorder; - } - - /** - * Start anomaly detector job. - * 1. If job doesn't exist, create new job. - * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. - * @param detector anomaly detector - * @param listener Listener to send responses - */ - public void startAnomalyDetectorJob(AnomalyDetector detector, ActionListener listener) { - // this start listener is created & injected throughout the job handler so that whenever the job response is received, - // there's the extra step of trying to index results and update detector state with a 60s delay. 
- ActionListener startListener = ActionListener.wrap(r -> { - try { - Instant executionEndTime = Instant.now(); - IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) detector.getInterval(); - Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); - AnomalyResultRequest getRequest = new AnomalyResultRequest( - detector.getId(), - executionStartTime.toEpochMilli(), - executionEndTime.toEpochMilli() - ); - client - .execute( - AnomalyResultAction.INSTANCE, - getRequest, - ActionListener - .wrap( - response -> recorder.indexAnomalyResult(executionStartTime, executionEndTime, response, detector), - exception -> { - - recorder - .indexAnomalyResultException( - executionStartTime, - executionEndTime, - Throwables.getStackTraceAsString(exception), - null, - detector - ); - } - ) - ); - } catch (Exception ex) { - listener.onFailure(ex); - return; - } - listener.onResponse(r); - - }, listener::onFailure); - if (!anomalyDetectionIndices.doesJobIndexExist()) { - anomalyDetectionIndices.initJobIndex(ActionListener.wrap(response -> { - if (response.isAcknowledged()) { - logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); - createJob(detector, startListener); - } else { - logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); - startListener - .onFailure( - new OpenSearchStatusException( - "Created " + CommonName.CONFIG_INDEX + " with mappings call not acknowledged.", - RestStatus.INTERNAL_SERVER_ERROR - ) - ); - } - }, exception -> startListener.onFailure(exception))); - } else { - createJob(detector, startListener); - } - } - - private void createJob(AnomalyDetector detector, ActionListener listener) { - try { - IntervalTimeConfiguration interval = (IntervalTimeConfiguration) detector.getInterval(); - Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); - Duration duration = Duration.of(interval.getInterval(), interval.getUnit()); - - Job job = new Job( - detector.getId(), - schedule, - detector.getWindowDelay(), - true, - Instant.now(), - null, - Instant.now(), - duration.getSeconds(), - detector.getUser(), - detector.getCustomResultIndex() - ); - - getJobForWrite(detector, job, listener); - } catch (Exception e) { - String message = "Failed to parse anomaly detector job " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } - - private void getJobForWrite(AnomalyDetector detector, Job job, ActionListener listener) { - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - - client - .get( - getRequest, - ActionListener - .wrap( - response -> onGetAnomalyDetectorJobForWrite(response, detector, job, listener), - exception -> listener.onFailure(exception) - ) - ); - } - - private void onGetAnomalyDetectorJobForWrite( - GetResponse response, - AnomalyDetector detector, - Job job, - ActionListener listener - ) throws IOException { - if (response.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job currentAdJob = Job.parse(parser); - if (currentAdJob.isEnabled()) { - listener - .onFailure(new OpenSearchStatusException("Anomaly detector job is already running: " + detectorId, RestStatus.OK)); - return; - } else { - Job newJob = new Job( - job.getName(), - 
job.getSchedule(),
- job.getWindowDelay(),
- job.isEnabled(),
- Instant.now(),
- currentAdJob.getDisabledTime(),
- Instant.now(),
- job.getLockDurationSeconds(),
- job.getUser(),
- job.getCustomResultIndex()
- );
- // Get the latest realtime task and check its state before indexing the job. Will reset a running realtime task
- // as STOPPED first if the job is disabled, then start a new job and create a new realtime task.
- adTaskManager.startDetector(detector, null, job.getUser(), transportService, ActionListener.wrap(r -> {
- indexAnomalyDetectorJob(newJob, null, listener);
- }, e -> {
- // Error message already logged in ADTaskManager#startDetector
- listener.onFailure(e);
- }));
- }
- } catch (IOException e) {
- String message = "Failed to parse anomaly detector job " + detectorId;
- logger.error(message, e);
- listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
- }
- } else {
- adTaskManager.startDetector(detector, null, job.getUser(), transportService, ActionListener.wrap(r -> {
- indexAnomalyDetectorJob(job, null, listener);
- }, e -> listener.onFailure(e)));
- }
- }
-
- private void indexAnomalyDetectorJob(Job job, ExecutorFunction function, ActionListener listener)
- throws IOException {
- IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE))
- .setIfSeqNo(seqNo)
- .setIfPrimaryTerm(primaryTerm)
- .timeout(requestTimeout)
- .id(detectorId);
- client
- .index(
- indexRequest,
- ActionListener
- .wrap(
- response -> onIndexAnomalyDetectorJobResponse(response, function, listener),
- exception -> listener.onFailure(exception)
- )
- );
- }
-
- private void onIndexAnomalyDetectorJobResponse(
- IndexResponse response,
- ExecutorFunction function,
- ActionListener listener
- ) {
- if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) {
- String errorMsg = getShardsFailure(response);
- listener.onFailure(new OpenSearchStatusException(errorMsg, response.status()));
- return;
- }
- if (function != null) {
- function.execute();
- } else {
- AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse(
- response.getId(),
- response.getVersion(),
- response.getSeqNo(),
- response.getPrimaryTerm(),
- RestStatus.OK
- );
- listener.onResponse(anomalyDetectorJobResponse);
- }
- }
-
- /**
- * Stop anomaly detector job.
- * 1. If job doesn't exist, return error message.
- * 2. If job exists: a). if job disabled, return error message; b). if job enabled, disable job.
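 * <p>A minimal caller sketch (hypothetical names, for illustration only):
 * <pre>
 *   handler.stopAnomalyDetectorJob(detectorId, ActionListener.wrap(
 *       r -> logger.info("stopped job for detector {}", detectorId),
 *       e -> logger.error("failed to stop job for detector " + detectorId, e)));
 * </pre>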
- * - * @param detectorId detector identifier - * @param listener Listener to send responses - */ - public void stopAnomalyDetectorJob(String detectorId, ActionListener listener) { - GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - - client.get(getRequest, ActionListener.wrap(response -> { - if (response.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - if (!job.isEnabled()) { - adTaskManager.stopLatestRealtimeTask(detectorId, TaskState.STOPPED, null, transportService, listener); - } else { - Job newJob = new Job( - job.getName(), - job.getSchedule(), - job.getWindowDelay(), - false, - job.getEnabledTime(), - Instant.now(), - Instant.now(), - job.getLockDurationSeconds(), - job.getUser(), - job.getCustomResultIndex() - ); - indexAnomalyDetectorJob( - newJob, - () -> client - .execute( - StopDetectorAction.INSTANCE, - new StopDetectorRequest(detectorId), - stopAdDetectorListener(detectorId, listener) - ), - listener - ); - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } else { - listener.onFailure(new OpenSearchStatusException("Anomaly detector job not exist: " + detectorId, RestStatus.BAD_REQUEST)); - } - }, exception -> listener.onFailure(exception))); - } - - private ActionListener stopAdDetectorListener( - String detectorId, - ActionListener listener - ) { - return new ActionListener() { - @Override - public void onResponse(StopDetectorResponse stopDetectorResponse) { - if (stopDetectorResponse.success()) { - logger.info("AD model deleted successfully for detector {}", detectorId); - // StopDetectorTransportAction will send out DeleteModelAction which will clear all realtime cache. - // Pass null transport service to method "stopLatestRealtimeTask" to not re-clear coordinating node cache. - adTaskManager.stopLatestRealtimeTask(detectorId, TaskState.STOPPED, null, null, listener); - } else { - logger.error("Failed to delete AD model for detector {}", detectorId); - // If failed to clear all realtime cache, will try to re-clear coordinating node cache. - adTaskManager - .stopLatestRealtimeTask( - detectorId, - TaskState.FAILED, - new OpenSearchStatusException("Failed to delete AD model", RestStatus.INTERNAL_SERVER_ERROR), - transportService, - listener - ); - } - } - - @Override - public void onFailure(Exception e) { - logger.error("Failed to delete AD model for detector " + detectorId, e); - // If failed to clear all realtime cache, will try to re-clear coordinating node cache. 
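 // (This mirrors the unsuccessful-response branch above: the latest realtime task is marked
 // FAILED and transportService is passed so the coordinating-node cache is cleared again,
 // whereas the success path marks the task STOPPED with a null transport service.)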
- adTaskManager - .stopLatestRealtimeTask( - detectorId, - TaskState.FAILED, - new OpenSearchStatusException("Failed to execute stop detector action", RestStatus.INTERNAL_SERVER_ERROR), - transportService, - listener - ); - } - }; - } - -} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java deleted file mode 100644 index 1f26d6bbd..000000000 --- a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java +++ /dev/null @@ -1,845 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ad.rest.handler; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CONFIG_BUCKET_MINIMUM_SUCCESS_RATE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TIMES_DECREASING_INTERVAL; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.TOP_VALIDATE_TIMEOUT_IN_MILLIS; - -import java.io.IOException; -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.RangeQueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; -import org.opensearch.search.aggregations.AggregationBuilders; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.BucketOrder; -import org.opensearch.search.aggregations.PipelineAggregatorBuilders; -import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; -import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; -import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; 
-import org.opensearch.search.aggregations.bucket.histogram.Histogram; -import org.opensearch.search.aggregations.bucket.histogram.LongBounds; -import org.opensearch.search.aggregations.bucket.terms.Terms; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.FieldSortBuilder; -import org.opensearch.search.sort.SortOrder; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.model.Feature; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.MergeableList; -import org.opensearch.timeseries.model.TimeConfiguration; -import org.opensearch.timeseries.model.ValidationAspect; -import org.opensearch.timeseries.model.ValidationIssueType; -import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; -import org.opensearch.timeseries.util.ParseUtils; -import org.opensearch.timeseries.util.SecurityClientUtil; - -/** - *
This class executes all validation checks that are not blocking on the 'model' level.
- * This mostly involves checking whether the data is generally dense enough to complete model training,
- * which is based on whether enough buckets in the last x intervals have at least 1 document present.
- *
- * Initially, different bucket aggregations are executed with every configuration applied and with
- * varying intervals in order to find the best interval for the data. If no interval is found with all
- * configurations applied, then each configuration is tested sequentially for sparsity.
- */
-// TODO: Add more UT and IT
-public class ModelValidationActionHandler {
- protected static final String AGG_NAME_TOP = "top_agg";
- protected static final String AGGREGATION = "agg";
- protected final AnomalyDetector anomalyDetector;
- protected final ClusterService clusterService;
- protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class);
- protected final TimeValue requestTimeout;
- protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler();
- protected final Client client;
- protected final SecurityClientUtil clientUtil;
- protected final NamedXContentRegistry xContentRegistry;
- protected final ActionListener listener;
- protected final SearchFeatureDao searchFeatureDao;
- protected final Clock clock;
- protected final String validationType;
- protected final Settings settings;
- protected final User user;
-
- /**
- * Constructor function.
- *
- * @param clusterService ClusterService
- * @param client ES node client that executes actions on the local node
- * @param clientUtil AD client util
- * @param listener ES channel used to construct bytes / builder based outputs, and send responses
- * @param anomalyDetector anomaly detector instance
- * @param requestTimeout request time out configuration
- * @param xContentRegistry Registry which is used for XContentParser
- * @param searchFeatureDao Search feature DAO
- * @param validationType Specified type for validation
- * @param clock clock object to know when to timeout
- * @param settings Node settings
- * @param user User info
- */
- public ModelValidationActionHandler(
- ClusterService clusterService,
- Client client,
- SecurityClientUtil clientUtil,
- ActionListener listener,
- AnomalyDetector anomalyDetector,
- TimeValue requestTimeout,
- NamedXContentRegistry xContentRegistry,
- SearchFeatureDao searchFeatureDao,
- String validationType,
- Clock clock,
- Settings settings,
- User user
- ) {
- this.clusterService = clusterService;
- this.client = client;
- this.clientUtil = clientUtil;
- this.listener = listener;
- this.anomalyDetector = anomalyDetector;
- this.requestTimeout = requestTimeout;
- this.xContentRegistry = xContentRegistry;
- this.searchFeatureDao = searchFeatureDao;
- this.validationType = validationType;
- this.clock = clock;
- this.settings = settings;
- this.user = user;
- }
-
- // Need to first check whether this is a multi-entity detector before doing any sort of validation.
- // If the detector is HCAD, then we will find the top entity and treat it as a single entity for
- // validation purposes.
- public void checkIfMultiEntityDetector() {
- ActionListener> recommendationListener = ActionListener
- .wrap(topEntity -> getLatestDateForValidation(topEntity), exception -> {
- listener.onFailure(exception);
- logger.error("Failed to get top entity for categorical field", exception);
- });
- if (anomalyDetector.isHighCardinality()) {
- getTopEntity(recommendationListener);
- } else {
- recommendationListener.onResponse(Collections.emptyMap());
- }
- }
-
- // For single-category HCAD, this method uses a bucket aggregation and sort to get the category field
- // that has the highest document count, in order to use that top entity for further validation.
- // For multi-category HCAD we use a composite aggregation to find the top fields for the entity
- // with the highest doc count.
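 // For reference, with a single category field named e.g. "host" (a hypothetical name), the
 // request assembled below amounts to roughly:
 //   { "size": 0, "track_total_hits": false,
 //     "query": { "range": { "<time field>": { "from": <min>, "to": <max> } } },
 //     "aggs": { "top_agg": { "terms": { "field": "host", "order": { "_count": "asc" } } } } }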
- private void getTopEntity(ActionListener> topEntityListener) { - // Look at data back to the lower bound given the max interval we recommend or one given - long maxIntervalInMinutes = Math.max(MAX_INTERVAL_REC_LENGTH_IN_MINUTES, anomalyDetector.getIntervalInMinutes()); - LongBounds timeRangeBounds = getTimeRangeBounds( - Instant.now().toEpochMilli(), - new IntervalTimeConfiguration(maxIntervalInMinutes, ChronoUnit.MINUTES) - ); - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) - .from(timeRangeBounds.getMin()) - .to(timeRangeBounds.getMax()); - AggregationBuilder bucketAggs; - Map topKeys = new HashMap<>(); - if (anomalyDetector.getCategoryFields().size() == 1) { - bucketAggs = AggregationBuilders - .terms(AGG_NAME_TOP) - .field(anomalyDetector.getCategoryFields().get(0)) - .order(BucketOrder.count(true)); - } else { - bucketAggs = AggregationBuilders - .composite( - AGG_NAME_TOP, - anomalyDetector - .getCategoryFields() - .stream() - .map(f -> new TermsValuesSourceBuilder(f).field(f)) - .collect(Collectors.toList()) - ) - .size(1000) - .subAggregation( - PipelineAggregatorBuilders - .bucketSort("bucketSort", Collections.singletonList(new FieldSortBuilder("_count").order(SortOrder.DESC))) - .size(1) - ); - } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(rangeQuery) - .aggregation(bucketAggs) - .trackTotalHits(false) - .size(0); - SearchRequest searchRequest = new SearchRequest() - .indices(anomalyDetector.getIndices().toArray(new String[0])) - .source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(response -> { - Aggregations aggs = response.getAggregations(); - if (aggs == null) { - topEntityListener.onResponse(Collections.emptyMap()); - return; - } - if (anomalyDetector.getCategoryFields().size() == 1) { - Terms entities = aggs.get(AGG_NAME_TOP); - Object key = entities - .getBuckets() - .stream() - .max(Comparator.comparingInt(entry -> (int) entry.getDocCount())) - .map(MultiBucketsAggregation.Bucket::getKeyAsString) - .orElse(null); - topKeys.put(anomalyDetector.getCategoryFields().get(0), key); - } else { - CompositeAggregation compositeAgg = aggs.get(AGG_NAME_TOP); - topKeys - .putAll( - compositeAgg - .getBuckets() - .stream() - .flatMap(bucket -> bucket.getKey().entrySet().stream()) // this would create a flattened stream of map entries - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())) - ); - } - for (Map.Entry entry : topKeys.entrySet()) { - if (entry.getValue() == null) { - topEntityListener.onResponse(Collections.emptyMap()); - return; - } - } - topEntityListener.onResponse(topKeys); - }, topEntityListener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private void getLatestDateForValidation(Map topEntity) { - ActionListener> latestTimeListener = ActionListener - .wrap(latest -> getSampleRangesForValidationChecks(latest, anomalyDetector, listener, topEntity), exception -> { - listener.onFailure(exception); - logger.error("Failed to create search request for last data point", exception); - }); - searchFeatureDao.getLatestDataTime(anomalyDetector, latestTimeListener); - } - - private void getSampleRangesForValidationChecks( - Optional latestTime, - AnomalyDetector detector, - ActionListener 
listener, - Map topEntity - ) { - if (!latestTime.isPresent() || latestTime.get() <= 0) { - listener - .onFailure( - new ValidationException( - ADCommonMessages.TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA, - ValidationIssueType.TIMEFIELD_FIELD, - ValidationAspect.MODEL - ) - ); - return; - } - long timeRangeEnd = Math.min(Instant.now().toEpochMilli(), latestTime.get()); - try { - getBucketAggregates(timeRangeEnd, listener, topEntity); - } catch (IOException e) { - listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); - } - } - - private void getBucketAggregates( - long latestTime, - ActionListener listener, - Map topEntity - ) throws IOException { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - if (anomalyDetector.isHighCardinality()) { - if (topEntity.isEmpty()) { - listener - .onFailure( - new ValidationException( - ADCommonMessages.CATEGORY_FIELD_TOO_SPARSE, - ValidationIssueType.CATEGORY, - ValidationAspect.MODEL - ) - ); - return; - } - for (Map.Entry entry : topEntity.entrySet()) { - query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); - } - } - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() - .query(query) - .aggregation(aggregation) - .size(0) - .timeout(requestTimeout); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - ActionListener intervalListener = ActionListener - .wrap(interval -> processIntervalRecommendation(interval, latestTime), exception -> { - listener.onFailure(exception); - logger.error("Failed to get interval recommendation", exception); - }); - final ActionListener searchResponseListener = - new ModelValidationActionHandler.DetectorIntervalRecommendationListener( - intervalListener, - searchRequest.source(), - (IntervalTimeConfiguration) anomalyDetector.getInterval(), - clock.millis() + TOP_VALIDATE_TIMEOUT_IN_MILLIS, - latestTime, - false, - MAX_TIMES_DECREASING_INTERVAL - ); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private double processBucketAggregationResults(Histogram buckets) { - int docCountOverOne = 0; - // For each entry - for (Histogram.Bucket entry : buckets.getBuckets()) { - if (entry.getDocCount() > 0) { - docCountOverOne++; - } - } - return (docCountOverOne / (double) getNumberOfSamples()); - } - - /** - * ActionListener class to handle execution of multiple bucket aggregations one after the other - * Bucket aggregation with different interval lengths are executed one by one to check if the data is dense enough - * We only need to execute the next query if the previous one led to data that is too sparse. 
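 * For example, with INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2 a 10-minute interval is
 * retried at ceil(10 * 1.2) = 12 minutes, then ceil(12 * 1.2) = 15 minutes, and so on, until the
 * bucket success rate exceeds INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE (0.75) or the next interval
 * would reach MAX_INTERVAL_REC_LENGTH_IN_MINUTES (60), at which point the listener switches to
 * decreasing intervals using INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER (0.8) instead.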
- */ - class DetectorIntervalRecommendationListener implements ActionListener { - private final ActionListener intervalListener; - SearchSourceBuilder searchSourceBuilder; - IntervalTimeConfiguration detectorInterval; - private final long expirationEpochMs; - private final long latestTime; - boolean decreasingInterval; - int numTimesDecreasing; // maximum amount of times we will try decreasing interval for recommendation - - DetectorIntervalRecommendationListener( - ActionListener intervalListener, - SearchSourceBuilder searchSourceBuilder, - IntervalTimeConfiguration detectorInterval, - long expirationEpochMs, - long latestTime, - boolean decreasingInterval, - int numTimesDecreasing - ) { - this.intervalListener = intervalListener; - this.searchSourceBuilder = searchSourceBuilder; - this.detectorInterval = detectorInterval; - this.expirationEpochMs = expirationEpochMs; - this.latestTime = latestTime; - this.decreasingInterval = decreasingInterval; - this.numTimesDecreasing = numTimesDecreasing; - } - - @Override - public void onResponse(SearchResponse response) { - try { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - - long newIntervalMinute; - if (decreasingInterval) { - newIntervalMinute = (long) Math - .floor( - IntervalTimeConfiguration.getIntervalInMinute(detectorInterval) * INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER - ); - } else { - newIntervalMinute = (long) Math - .ceil( - IntervalTimeConfiguration.getIntervalInMinute(detectorInterval) * INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER - ); - } - double fullBucketRate = processBucketAggregationResults(aggregate); - // If rate is above success minimum then return interval suggestion. - if (fullBucketRate > INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { - intervalListener.onResponse(this.detectorInterval); - } else if (expirationEpochMs < clock.millis()) { - listener - .onFailure( - new ValidationException( - ADCommonMessages.TIMEOUT_ON_INTERVAL_REC, - ValidationIssueType.TIMEOUT, - ValidationAspect.MODEL - ) - ); - logger.info(ADCommonMessages.TIMEOUT_ON_INTERVAL_REC); - // keep trying higher intervals as new interval is below max, and we aren't decreasing yet - } else if (newIntervalMinute < MAX_INTERVAL_REC_LENGTH_IN_MINUTES && !decreasingInterval) { - searchWithDifferentInterval(newIntervalMinute); - // The below block is executed only the first time when new interval is above max and - // we aren't decreasing yet, at this point we will start decreasing for the first time - // if we are inside the below block - } else if (newIntervalMinute >= MAX_INTERVAL_REC_LENGTH_IN_MINUTES && !decreasingInterval) { - IntervalTimeConfiguration givenInterval = (IntervalTimeConfiguration) anomalyDetector.getInterval(); - this.detectorInterval = new IntervalTimeConfiguration( - (long) Math - .floor( - IntervalTimeConfiguration.getIntervalInMinute(givenInterval) * INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER - ), - ChronoUnit.MINUTES - ); - if (detectorInterval.getInterval() <= 0) { - intervalListener.onResponse(null); - return; - } - this.decreasingInterval = true; - this.numTimesDecreasing -= 1; - // Searching again using an updated interval - SearchSourceBuilder updatedSearchSourceBuilder = getSearchSourceBuilder( - searchSourceBuilder.query(), - getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinute, ChronoUnit.MINUTES)) - ); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - 
clientUtil - .asyncRequestWithInjectedSecurity( - new SearchRequest() - .indices(anomalyDetector.getIndices().toArray(new String[0])) - .source(updatedSearchSourceBuilder), - client::search, - user, - client, - AnalysisType.AD, - this - ); - // In this case decreasingInterval has to be true already, so we will stop - // when the next new interval is below or equal to 0, or we have decreased up to max times - } else if (numTimesDecreasing >= 0 && newIntervalMinute > 0) { - this.numTimesDecreasing -= 1; - searchWithDifferentInterval(newIntervalMinute); - // this case means all intervals up to max interval recommendation length and down to either - // 0 or until we tried 10 lower intervals than the one given have been tried - // which further means the next step is to go through A/B validation checks - } else { - intervalListener.onResponse(null); - } - - } catch (Exception e) { - onFailure(e); - } - } - - private void searchWithDifferentInterval(long newIntervalMinuteValue) { - this.detectorInterval = new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES); - // Searching again using an updated interval - SearchSourceBuilder updatedSearchSourceBuilder = getSearchSourceBuilder( - searchSourceBuilder.query(), - getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES)) - ); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(updatedSearchSourceBuilder), - client::search, - user, - client, - AnalysisType.AD, - this - ); - } - - @Override - public void onFailure(Exception e) { - logger.error("Failed to recommend new interval", e); - listener - .onFailure( - new ValidationException( - ADCommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, - ValidationIssueType.AGGREGATION, - ValidationAspect.MODEL - ) - ); - } - } - - private void processIntervalRecommendation(IntervalTimeConfiguration interval, long latestTime) { - // if interval suggestion is null that means no interval could be found with all the configurations - // applied, our next step then is to check density just with the raw data and then add each configuration - // one at a time to try and find root cause of low density - if (interval == null) { - checkRawDataSparsity(latestTime); - } else { - if (interval.equals(anomalyDetector.getInterval())) { - logger.info("Using the current interval there is enough dense data "); - // Check if there is a window delay recommendation if everything else is successful and send exception - if (Instant.now().toEpochMilli() - latestTime > timeConfigToMilliSec(anomalyDetector.getWindowDelay())) { - sendWindowDelayRec(latestTime); - return; - } - // The rate of buckets with at least 1 doc with given interval is above the success rate - listener.onResponse(null); - return; - } - // return response with interval recommendation - listener - .onFailure( - new ValidationException( - ADCommonMessages.DETECTOR_INTERVAL_REC + interval.getInterval(), - ValidationIssueType.DETECTION_INTERVAL, - ValidationAspect.MODEL, - interval - ) - ); - } - } - - private AggregationBuilder getBucketAggregation(long latestTime, IntervalTimeConfiguration detectorInterval) { - return AggregationBuilders - .dateHistogram(AGGREGATION) - .field(anomalyDetector.getTimeField()) - .minDocCount(1) - .hardBounds(getTimeRangeBounds(latestTime, 
detectorInterval)) - .fixedInterval(DateHistogramInterval.minutes((int) IntervalTimeConfiguration.getIntervalInMinute(detectorInterval))); - } - - private SearchSourceBuilder getSearchSourceBuilder(QueryBuilder query, AggregationBuilder aggregation) { - return new SearchSourceBuilder().query(query).aggregation(aggregation).size(0).timeout(requestTimeout); - } - - private void checkRawDataSparsity(long latestTime) { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(aggregation).size(0).timeout(requestTimeout); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener - .wrap(response -> processRawDataResults(response, latestTime), listener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private Histogram checkBucketResultErrors(SearchResponse response) { - Aggregations aggs = response.getAggregations(); - if (aggs == null) { - // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with - // the large amounts of changes there). For this reason I'm not throwing a SearchException but instead a validation exception - // which will be converted to validation response. - logger.warn("Unexpected null aggregation."); - listener - .onFailure( - new ValidationException( - ADCommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, - ValidationIssueType.AGGREGATION, - ValidationAspect.MODEL - ) - ); - return null; - } - Histogram aggregate = aggs.get(AGGREGATION); - if (aggregate == null) { - listener.onFailure(new IllegalArgumentException("Failed to find valid aggregation result")); - return null; - } - return aggregate; - } - - private void processRawDataResults(SearchResponse response, long latestTime) { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { - listener - .onFailure( - new ValidationException(ADCommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL) - ); - } else { - checkDataFilterSparsity(latestTime); - } - } - - private void checkDataFilterSparsity(long latestTime) { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener - .wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, 
- client, - AnalysisType.AD, - searchResponseListener - ); - } - - private void processDataFilterResults(SearchResponse response, long latestTime) { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { - listener - .onFailure( - new ValidationException( - ADCommonMessages.FILTER_QUERY_TOO_SPARSE, - ValidationIssueType.FILTER_QUERY, - ValidationAspect.MODEL - ) - ); - // blocks below are executed if data is dense enough with filter query applied. - // If HCAD then category fields will be added to bucket aggregation to see if they - // are the root cause of the issues and if not the feature queries will be checked for sparsity - } else if (anomalyDetector.isHighCardinality()) { - getTopEntityForCategoryField(latestTime); - } else { - try { - checkFeatureQueryDelegate(latestTime); - } catch (Exception ex) { - logger.error(ex); - listener.onFailure(ex); - } - } - } - - private void getTopEntityForCategoryField(long latestTime) { - ActionListener> getTopEntityListener = ActionListener - .wrap(topEntity -> checkCategoryFieldSparsity(topEntity, latestTime), exception -> { - listener.onFailure(exception); - logger.error("Failed to get top entity for categorical field", exception); - return; - }); - getTopEntity(getTopEntityListener); - } - - private void checkCategoryFieldSparsity(Map topEntity, long latestTime) { - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - for (Map.Entry entry : topEntity.entrySet()) { - query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); - } - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener - .wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - - private void processTopEntityResults(SearchResponse response, long latestTime) { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { - listener - .onFailure( - new ValidationException( - ADCommonMessages.CATEGORY_FIELD_TOO_SPARSE, - ValidationIssueType.CATEGORY, - ValidationAspect.MODEL - ) - ); - } else { - try { - checkFeatureQueryDelegate(latestTime); - } catch (Exception ex) { - logger.error(ex); - listener.onFailure(ex); - } - } - } - - private void checkFeatureQueryDelegate(long latestTime) throws IOException { - ActionListener> validateFeatureQueriesListener = ActionListener.wrap(response -> { - windowDelayRecommendation(latestTime); - }, exception -> { - listener - .onFailure(new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.MODEL)); - }); - MultiResponsesDelegateActionListener> 
multiFeatureQueriesResponseListener = - new MultiResponsesDelegateActionListener<>( - validateFeatureQueriesListener, - anomalyDetector.getFeatureAttributes().size(), - ADCommonMessages.FEATURE_QUERY_TOO_SPARSE, - false - ); - - for (Feature feature : anomalyDetector.getFeatureAttributes()) { - AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); - BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); - List featureFields = ParseUtils.getFieldNamesForFeature(feature, xContentRegistry); - for (String featureField : featureFields) { - query.filter(QueryBuilders.existsQuery(featureField)); - } - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])) - .source(searchSourceBuilder); - final ActionListener searchResponseListener = ActionListener.wrap(response -> { - Histogram aggregate = checkBucketResultErrors(response); - if (aggregate == null) { - return; - } - double fullBucketRate = processBucketAggregationResults(aggregate); - if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { - multiFeatureQueriesResponseListener - .onFailure( - new ValidationException( - ADCommonMessages.FEATURE_QUERY_TOO_SPARSE, - ValidationIssueType.FEATURE_ATTRIBUTES, - ValidationAspect.MODEL - ) - ); - } else { - multiFeatureQueriesResponseListener - .onResponse(new MergeableList<>(new ArrayList<>(Collections.singletonList(new double[] { fullBucketRate })))); - } - }, e -> { - logger.error(e); - multiFeatureQueriesResponseListener - .onFailure(new OpenSearchStatusException(ADCommonMessages.FEATURE_QUERY_TOO_SPARSE, RestStatus.BAD_REQUEST, e)); - }); - // using the original context in listener as user roles have no permissions for internal operations like fetching a - // checkpoint - clientUtil - .asyncRequestWithInjectedSecurity( - searchRequest, - client::search, - user, - client, - AnalysisType.AD, - searchResponseListener - ); - } - } - - private void sendWindowDelayRec(long latestTimeInMillis) { - long minutesSinceLastStamp = (long) Math.ceil((Instant.now().toEpochMilli() - latestTimeInMillis) / 60000.0); - listener - .onFailure( - new ValidationException( - String.format(Locale.ROOT, ADCommonMessages.WINDOW_DELAY_REC, minutesSinceLastStamp, minutesSinceLastStamp), - ValidationIssueType.WINDOW_DELAY, - ValidationAspect.MODEL, - new IntervalTimeConfiguration(minutesSinceLastStamp, ChronoUnit.MINUTES) - ) - ); - } - - private void windowDelayRecommendation(long latestTime) { - // Check if there is a better window-delay to recommend and if one was recommended - // then send exception and return, otherwise continue to let user know data is too sparse as explained below - if (Instant.now().toEpochMilli() - latestTime > timeConfigToMilliSec(anomalyDetector.getWindowDelay())) { - sendWindowDelayRec(latestTime); - return; - } - // This case has been reached if following conditions are met: - // 1. no interval recommendation was found that leads to a bucket success rate of >= 0.75 - // 2. bucket success rate with the given interval and just raw data is also below 0.75. - // 3. 
no single configuration during the following checks reduced the bucket success rate below 0.25 - // This means the rate with all configs applied or just raw data was below 0.75 but the rate when checking each configuration at - // a time was always above 0.25 meaning the best suggestion is to simply ingest more data or change interval since - // we have no more insight regarding the root cause of the lower density. - listener - .onFailure(new ValidationException(ADCommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL)); - } - - private LongBounds getTimeRangeBounds(long endMillis, IntervalTimeConfiguration detectorIntervalInMinutes) { - Long detectorInterval = timeConfigToMilliSec(detectorIntervalInMinutes); - Long startMillis = endMillis - (getNumberOfSamples() * detectorInterval); - return new LongBounds(startMillis, endMillis); - } - - private int getNumberOfSamples() { - long interval = anomalyDetector.getIntervalInMilliseconds(); - return Math - .max( - (int) (Duration.ofHours(AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS).toMillis() / interval), - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES - ); - } - - private Long timeConfigToMilliSec(TimeConfiguration config) { - return Optional.ofNullable((IntervalTimeConfiguration) config).map(t -> t.toDuration().toMillis()).orElse(0L); - } -} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java index 3c0b13c5e..cf52a2237 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java @@ -14,24 +14,23 @@ import java.time.Clock; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestRequest; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; /** * Anomaly detector REST action handler to process POST request. * POST request is for validating anomaly detector against detector and/or model configs. */ -public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetectorActionHandler { +public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetectorActionHandler { /** * Constructor function. 
@@ -39,13 +38,13 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto * @param clusterService ClusterService * @param client ES node client that executes actions on the local node * @param clientUtil AD client utility - * @param listener ES channel used to construct bytes / builder based outputs, and send responses * @param anomalyDetectionIndices anomaly detector index manager * @param anomalyDetector anomaly detector instance * @param requestTimeout request time out configuration * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed * @param maxAnomalyFeatures max features allowed per detector + * @param maxCategoricalFields max number of categorical fields * @param method Rest Method type * @param xContentRegistry Registry which is used for XContentParser * @param user User context @@ -58,13 +57,13 @@ public ValidateAnomalyDetectorActionHandler( ClusterService clusterService, Client client, SecurityClientUtil clientUtil, - ActionListener listener, ADIndexManagement anomalyDetectionIndices, - AnomalyDetector anomalyDetector, + Config anomalyDetector, TimeValue requestTimeout, Integer maxSingleEntityAnomalyDetectors, Integer maxMultiEntityAnomalyDetectors, Integer maxAnomalyFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, @@ -78,9 +77,8 @@ public ValidateAnomalyDetectorActionHandler( client, clientUtil, null, - listener, anomalyDetectionIndices, - AnomalyDetector.NO_ID, + Config.NO_ID, null, null, null, @@ -89,6 +87,7 @@ public ValidateAnomalyDetectorActionHandler( maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry, user, @@ -100,16 +99,4 @@ public ValidateAnomalyDetectorActionHandler( settings ); } - - // If validation type is detector then all validation in AbstractAnomalyDetectorActionHandler that is called - // by super.start() involves validation checks against the detector configurations, - // any issues raised here would block user from creating the anomaly detector. - // If validation Aspect is of type model then further non-blocker validation will be executed - // after the blocker validation is executed. Any issues that are raised for model validation - // are simply warnings for the user in terms of how configuration could be changed to lead to - // a higher likelihood of model training completing successfully - @Override - public void start() { - super.start(); - } } diff --git a/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java b/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java index ed4414f6c..fc6fb2ce0 100644 --- a/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java +++ b/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java @@ -1,12 +1,6 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */
package org.opensearch.ad.settings;
@@ -34,6 +28,8 @@ public class ADEnabledSetting extends DynamicNumericSetting {
*/
public static final String AD_ENABLED = "plugins.anomaly_detection.enabled";
+ // use TimeSeriesEnabledSetting.BREAKER_ENABLED instead
+ @Deprecated
public static final String AD_BREAKER_ENABLED = "plugins.anomaly_detection.breaker.enabled";
public static final String LEGACY_OPENDISTRO_AD_ENABLED = "opendistro.anomaly_detection.enabled";
@@ -82,7 +78,7 @@ public class ADEnabledSetting extends DynamicNumericSetting {
* filter out unpopular items that are not likely to appear more
* than once. Whether this bloom filter is enabled or not.
*/
- put(DOOR_KEEPER_IN_CACHE_ENABLED, Setting.boolSetting(DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic));
+ put(DOOR_KEEPER_IN_CACHE_ENABLED, Setting.boolSetting(DOOR_KEEPER_IN_CACHE_ENABLED, true, NodeScope, Dynamic));
}
});
@@ -105,14 +101,6 @@ public static boolean isADEnabled() {
return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.AD_ENABLED);
}
- /**
- * Whether AD circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause an AD job to be stopped.
- * @return whether AD circuit breaker is enabled or not.
- */
- public static boolean isADBreakerEnabled() {
- return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.AD_BREAKER_ENABLED);
- }
-
/**
* If enabled, we use samples plus interpolation to train models.
* @return whether interpolation in HCAD cold start is enabled or not.
diff --git a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java
index e064867a0..869cdf412 100644
--- a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java
+++ b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java
@@ -1,12 +1,6 @@
/*
+ * Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
*/
package org.opensearch.ad.settings;
diff --git a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java
index b5f10b383..3e732b374 100644
--- a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java
+++ b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java
@@ -115,7 +115,7 @@ private AnomalyDetectorSettings() {}
/**
* @deprecated This setting is deprecated because we need to manage fault tolerance for
* multiple analysis such as AD and forecasting.
- * Use TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE instead.
+ * Use {@link TimeSeriesSettings#MAX_RETRY_FOR_UNRESPONSIVE_NODE} instead.
*/
@Deprecated
public static final Setting AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting
@@ -130,7 +130,7 @@ private AnomalyDetectorSettings() {}
/**
* @deprecated This setting is deprecated because we need to manage fault tolerance for
* multiple analysis such as AD and forecasting.
- * Use TimeSeriesSettings.COOLDOWN_MINUTES instead.
+ * Use {@link TimeSeriesSettings#COOLDOWN_MINUTES} instead.
*/
@Deprecated
public static final Setting AD_COOLDOWN_MINUTES = Setting
@@ -144,7 +144,7 @@ private AnomalyDetectorSettings() {}
/**
* @deprecated This setting is deprecated because we need to manage fault tolerance for
* multiple analysis such as AD and forecasting.
- * Use TimeSeriesSettings.BACKOFF_MINUTES instead.
+ * Use {@link TimeSeriesSettings#BACKOFF_MINUTES} instead.
*/
@Deprecated
public static final Setting AD_BACKOFF_MINUTES = Setting
@@ -238,10 +238,6 @@ private AnomalyDetectorSettings() {}
public static final int MAX_SAMPLE_STRIDE = 64;
- public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24;
-
- public static final int MIN_TRAIN_SAMPLES = 512;
-
public static final int MAX_IMPUTATION_NEIGHBOR_DISTANCE = 2;
// shingling
@@ -592,37 +588,6 @@ private AnomalyDetectorSettings() {}
Setting.Property.Dynamic
);
- /**
- * EntityRequest has entityName (# category fields * 256, the recommended limit
- * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest
- * fields including detector Id (roughly 128 bytes), expirationEpochMs (long,
- * 8 bytes), and priority (12 bytes).
- * Plus Java object size (12 bytes), we have roughly 928 bytes per request
- * assuming we have 2 categorical fields (plan to support 2 categorical fields now).
- * We don't want the total size to exceed 0.1% of the heap.
- * We can have at most 0.1% heap / 928 = heap / 928,000.
- * For t3.small, 0.1% of the heap is 1MB. The queue's size is up to
- * 10^6 / 928 = 1078
- */
- // to be replaced by TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES
- @Deprecated
- public static int ENTITY_REQUEST_SIZE_IN_BYTES = 928;
-
- /**
- * EntityFeatureRequest consists of EntityRequest (928 bytes, read comments
- * of ENTITY_COLD_START_QUEUE_SIZE_CONSTANT), pointer to current feature
- * (8 bytes), and dataStartTimeMillis (8 bytes). We have roughly
- * 928 + 16 = 944 bytes per request.
- *
- * We don't want the total size to exceed 0.1% of the heap.
- * We should have at most 0.1% heap / 944 = heap / 944,000
- * For t3.small, 0.1% of the heap is 1MB. The queue's size is up to
- * 10^6 / 944 = 1059
- */
- // to be replaced by TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES
- @Deprecated
- public static int ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES = 944;
-
// ======================================
// pagination setting
// ======================================
@@ -701,14 +666,6 @@ private AnomalyDetectorSettings() {}
Setting.Property.Dynamic
);
- // ======================================
- // Validate Detector API setting
- // ======================================
- public static final long TOP_VALIDATE_TIMEOUT_IN_MILLIS = 10_000;
- public static final long MAX_INTERVAL_REC_LENGTH_IN_MINUTES = 60L;
- public static final double INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2;
- public static final double INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER = 0.8;
- public static final double INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE = 0.75;
public static final double CONFIG_BUCKET_MINIMUM_SUCCESS_RATE = 0.25;
// This value is set to decrease the number of times we decrease the interval when recommending a new one
// The reason we need a max is because a user could give an arbitrarily large interval where we don't know even
diff --git a/src/main/java/org/opensearch/ad/stats/ADStats.java b/src/main/java/org/opensearch/ad/stats/ADStats.java
index 1fb0e8fe4..433b8b0aa 100644
--- a/src/main/java/org/opensearch/ad/stats/ADStats.java
+++ b/src/main/java/org/opensearch/ad/stats/ADStats.java
@@ -1,84 +1,19 @@
/*
+ * Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
*/
package org.opensearch.ad.stats;
-import java.util.HashMap;
import java.util.Map;
-/**
- * This class is the main entry-point for access to the stats that the AD plugin keeps track of.
- */ -public class ADStats { - - private Map> stats; - - /** - * Constructor - * - * @param stats Map of the stats that are to be kept - */ - public ADStats(Map> stats) { - this.stats = stats; - } +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.stats.TimeSeriesStat; - /** - * Get the stats - * - * @return all of the stats - */ - public Map> getStats() { - return stats; - } +public class ADStats extends Stats { - /** - * Get individual stat by stat name - * - * @param key Name of stat - * @return ADStat - * @throws IllegalArgumentException thrown on illegal statName - */ - public ADStat getStat(String key) throws IllegalArgumentException { - if (!stats.keySet().contains(key)) { - throw new IllegalArgumentException("Stat=\"" + key + "\" does not exist"); - } - return stats.get(key); + public ADStats(Map> stats) { + super(stats); } - /** - * Get a map of the stats that are kept at the node level - * - * @return Map of stats kept at the node level - */ - public Map> getNodeStats() { - return getClusterOrNodeStats(false); - } - - /** - * Get a map of the stats that are kept at the cluster level - * - * @return Map of stats kept at the cluster level - */ - public Map> getClusterStats() { - return getClusterOrNodeStats(true); - } - - private Map> getClusterOrNodeStats(Boolean getClusterStats) { - Map> statsMap = new HashMap<>(); - - for (Map.Entry> entry : stats.entrySet()) { - if (entry.getValue().isClusterLevel() == getClusterStats) { - statsMap.put(entry.getKey(), entry.getValue()); - } - } - return statsMap; - } } diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeCountSupplier.java new file mode 100644 index 000000000..48cc36ebb --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeCountSupplier.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.stats.suppliers; + +import java.util.function.Supplier; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADModelManager; + +/** + * ModelsOnNodeCountSupplier provides the number of models a node contains + */ +public class ADModelsOnNodeCountSupplier implements Supplier { + private ADModelManager modelManager; + private ADCacheProvider adCache; + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param adCache object that manages multi-entity detectors' models + */ + public ADModelsOnNodeCountSupplier(ADModelManager modelManager, ADCacheProvider adCache) { + this.modelManager = modelManager; + this.adCache = adCache; + } + + @Override + public Long get() { + return Stream.concat(modelManager.getAllModels().stream(), adCache.get().getAllModels().stream()).count(); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeSupplier.java new file mode 100644 index 000000000..26b1cb8d5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ADModelsOnNodeSupplier.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.stats.suppliers; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; +import static 
org.opensearch.timeseries.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.LAST_USED_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.MODEL_TYPE_KEY; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.constant.CommonName; + +public class ADModelsOnNodeSupplier implements Supplier<List<Map<String, Object>>> { + private ADModelManager modelManager; + private ADCacheProvider adCache; + // the max number of models to return per node. Defaults to 100. + private volatile int adNumModelsToReturn; + + /** + * Set that contains the model stats that should be exposed. + */ + public static Set<String> MODEL_STATE_STAT_KEYS = new HashSet<>( + Arrays + .asList( + CommonName.MODEL_ID_FIELD, + ADCommonName.DETECTOR_ID_KEY, + MODEL_TYPE_KEY, + CommonName.ENTITY_KEY, + LAST_USED_TIME_KEY, + LAST_CHECKPOINT_TIME_KEY + ) + ); + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param adCache object that manages multi-entity detectors' models + * @param settings node settings accessor + * @param clusterService Cluster service accessor + */ + public ADModelsOnNodeSupplier(ADModelManager modelManager, ADCacheProvider adCache, Settings settings, ClusterService clusterService) { + this.modelManager = modelManager; + this.adCache = adCache; + this.adNumModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.adNumModelsToReturn = it); + + } + + @Override + public List<Map<String, Object>> get() { + Stream<Map<String, Object>> adStream = Stream + .concat(modelManager.getAllModels().stream(), adCache.get().getAllModels().stream()) + .limit(adNumModelsToReturn) + .map( + modelState -> modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + + return adStream.collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java deleted file mode 100644 index 8fdac74d7..000000000 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details.
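
// A stand-alone sketch of ADModelsOnNodeSupplier#get()'s core transformation: cap the
// number of model-state maps, then keep only whitelisted keys. Names and the main()
// data are illustrative.
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ModelStateFilterSketch {
    static List<Map<String, Object>> filter(Stream<Map<String, Object>> states, Set<String> allowedKeys, int limit) {
        return states
            .limit(limit) // per-node cap, AD_MAX_MODEL_SIZE_PER_NODE in the real code
            .map(state -> state.entrySet().stream()
                .filter(e -> allowedKeys.contains(e.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Map<String, Object> state = new HashMap<>();
        state.put("model_id", "m1");
        state.put("internal_detail", 42);
        // prints [{model_id=m1}]: only the whitelisted key survives
        System.out.println(filter(Stream.of(state), Set.of("model_id"), 100));
    }
}
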
- */ - -package org.opensearch.ad.stats.suppliers; - -import java.util.function.Supplier; -import java.util.stream.Stream; - -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.ModelManager; - -/** - * ModelsOnNodeCountSupplier provides the number of models a node contains - */ -public class ModelsOnNodeCountSupplier implements Supplier<Long> { - private ModelManager modelManager; - private CacheProvider cache; - - /** - * Constructor - * - * @param modelManager object that manages the model partitions hosted on the node - * @param cache object that manages multi-entity detectors' models - */ - public ModelsOnNodeCountSupplier(ModelManager modelManager, CacheProvider cache) { - this.modelManager = modelManager; - this.cache = cache; - } - - @Override - public Long get() { - return Stream.concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()).count(); - } -} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java deleted file mode 100644 index 2cdee5fb8..000000000 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.stats.suppliers; - -import static org.opensearch.ad.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; -import static org.opensearch.ad.ml.ModelState.LAST_USED_TIME_KEY; -import static org.opensearch.ad.ml.ModelState.MODEL_TYPE_KEY; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.timeseries.constant.CommonName; - -/** - * ModelsOnNodeSupplier provides a List of ModelStates info for the models the node contains - */ -public class ModelsOnNodeSupplier implements Supplier<List<Map<String, Object>>> { - private ModelManager modelManager; - private CacheProvider cache; - // the max number of models to return per node. Defaults to 100. - private volatile int numModelsToReturn; - - /** - * Set that contains the model stats that should be exposed.
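
// A sketch of the dynamic-setting pattern both suppliers use: read the setting once at
// construction, then register a consumer so a volatile field tracks later updates.
// SettingsRegistry is a hypothetical stand-in for ClusterSettings#addSettingsUpdateConsumer.
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

class DynamicSettingSketch {
    static final class SettingsRegistry {
        private final List<IntConsumer> consumers = new ArrayList<>();
        void addUpdateConsumer(IntConsumer c) { consumers.add(c); }
        void publish(int newValue) { consumers.forEach(c -> c.accept(newValue)); }
    }

    // volatile: read on stats threads, written by the settings-update callback
    private volatile int numModelsToReturn;

    DynamicSettingSketch(int initialValue, SettingsRegistry registry) {
        this.numModelsToReturn = initialValue;
        registry.addUpdateConsumer(it -> this.numModelsToReturn = it);
    }

    int current() { return numModelsToReturn; }

    public static void main(String[] args) {
        SettingsRegistry registry = new SettingsRegistry();
        DynamicSettingSketch sketch = new DynamicSettingSketch(100, registry);
        registry.publish(50); // simulate a cluster settings update
        System.out.println(sketch.current()); // 50
    }
}
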
- */ - public static Set<String> MODEL_STATE_STAT_KEYS = new HashSet<>( - Arrays - .asList( - CommonName.MODEL_ID_FIELD, - ADCommonName.DETECTOR_ID_KEY, - MODEL_TYPE_KEY, - CommonName.ENTITY_KEY, - LAST_USED_TIME_KEY, - LAST_CHECKPOINT_TIME_KEY - ) - ); - - /** - * Constructor - * - * @param modelManager object that manages the model partitions hosted on the node - * @param cache object that manages multi-entity detectors' models - * @param settings node settings accessor - * @param clusterService Cluster service accessor - */ - public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache, Settings settings, ClusterService clusterService) { - this.modelManager = modelManager; - this.cache = cache; - this.numModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); - } - - @Override - public List<Map<String, Object>> get() { - List<Map<String, Object>> values = new ArrayList<>(); - Stream - .concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()) - .limit(numModelsToReturn) - .forEach( - modelState -> values - .add( - modelState - .getModelStateAsMap() - .entrySet() - .stream() - .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) - ) - ); - - return values; - } -} diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java index 05897fe64..f0490618e 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java @@ -11,11 +11,6 @@ package org.opensearch.ad.task; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_SAMPLES_PER_TREE; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_TREES; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.TIME_DECAY; - import java.util.ArrayDeque; import java.util.Deque; import java.util.Map; @@ -62,7 +57,7 @@ protected ADBatchTaskCache(ADTask adTask) { this.entity = adTask.getEntity(); AnomalyDetector detector = adTask.getDetector(); - int numberOfTrees = NUM_TREES; + int numberOfTrees = TimeSeriesSettings.NUM_TREES; int shingleSize = detector.getShingleSize(); this.shingle = new ArrayDeque<>(shingleSize); int dimensions = detector.getShingleSize() * detector.getEnabledFeatureIds().size(); @@ -71,10 +66,10 @@ protected ADBatchTaskCache(ADTask adTask) { .builder() .dimensions(dimensions) .numberOfTrees(numberOfTrees) - .timeDecay(TIME_DECAY) - .sampleSize(NUM_SAMPLES_PER_TREE) - .outputAfter(NUM_MIN_SAMPLES) - .initialAcceptFraction(NUM_MIN_SAMPLES * 1.0d / NUM_SAMPLES_PER_TREE) + .timeDecay(TimeSeriesSettings.TIME_DECAY) + .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE) + .outputAfter(TimeSeriesSettings.NUM_MIN_SAMPLES) + .initialAcceptFraction(TimeSeriesSettings.NUM_MIN_SAMPLES * 1.0d / TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .parallelExecutionEnabled(false) .compact(true) .precision(Precision.FLOAT_32) diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index f25b09af4..fba6b0206 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -12,24 +12,16 @@ package org.opensearch.ad.task; import static 
org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; -import static org.opensearch.ad.model.ADTask.CURRENT_PIECE_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.WORKER_NODE_FIELD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS; -import static org.opensearch.ad.stats.InternalStatNames.JVM_HEAP_USAGE; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; import static org.opensearch.timeseries.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; +import static org.opensearch.timeseries.stats.InternalStatNames.JVM_HEAP_USAGE; import static org.opensearch.timeseries.stats.StatNames.AD_EXECUTING_BATCH_TASK_COUNT; -import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; import java.time.Clock; import java.time.Instant; @@ -49,14 +41,10 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.caching.PriorityTracker; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SinglePointFeatures; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; @@ -67,10 +55,7 @@ import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; import org.opensearch.ad.transport.ADBatchAnomalyResultResponse; import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionAction; -import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesAction; -import org.opensearch.ad.transport.ADStatsRequest; -import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -92,20 +77,29 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.PriorityTracker; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import 
org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TaskCancelledException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.SecurityClientUtil; @@ -131,14 +125,14 @@ public class ADBatchTaskRunner { private final FeatureManager featureManager; private final CircuitBreakerService adCircuitBreakerService; private final ADTaskManager adTaskManager; - private final AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler; + private final ResultBulkIndexingHandler anomalyResultBulkIndexHandler; private final ADIndexManagement anomalyDetectionIndices; private final SearchFeatureDao searchFeatureDao; private final ADTaskCacheManager adTaskCacheManager; private final TransportRequestOptions option; private final HashRing hashRing; - private final ModelManager modelManager; + private final ADModelManager modelManager; private volatile Integer maxAdBatchTaskPerNode; private volatile Integer pieceSize; @@ -160,11 +154,11 @@ public ADBatchTaskRunner( ADTaskManager adTaskManager, ADIndexManagement anomalyDetectionIndices, ADStats adStats, - AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler, + ResultBulkIndexingHandler anomalyResultBulkIndexHandler, ADTaskCacheManager adTaskCacheManager, SearchFeatureDao searchFeatureDao, HashRing hashRing, - ModelManager modelManager + ADModelManager modelManager ) { this.settings = settings; this.threadPool = threadPool; @@ -267,7 +261,7 @@ private ActionListener getTopEntitiesListener( adTaskCacheManager.setTopEntityInited(detectorId); int totalEntities = adTaskCacheManager.getPendingEntityCount(detectorId); logger.info("Total top entities: {} for detector {}, task {}", totalEntities, detectorId, taskId); - hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { + hashRing.getNodesWithSameLocalVersion(dataNodes -> { int numberOfEligibleDataNodes = dataNodes.length; // maxAdBatchTaskPerNode means how many task can run on per data node, which is hard limitation per node. // maxRunningEntitiesPerDetector means how many entities can run per detector on whole cluster, which is @@ -533,7 +527,7 @@ public void forwardOrExecuteADTask( ? 
adTask.getParentTaskId() // For HISTORICAL_HC_ENTITY task, return its parent task id : adTask.getTaskId(); // For HISTORICAL_HC_DETECTOR task, its task id is parent task id adTaskManager - .getAndExecuteOnLatestADTask( + .getAndExecuteOnLatestConfigTask( detectorId, parentTaskId, entity, @@ -578,7 +572,7 @@ public void forwardOrExecuteADTask( .entity(entity) .parentTaskId(parentTaskId) .build(); - adTaskManager.createADTaskDirectly(adEntityTask, r -> { + adTaskManager.createTaskDirectly(adEntityTask, r -> { adEntityTask.setTaskId(r.getId()); ActionListener workerNodeResponseListener = workerNodeResponseListener( adEntityTask, @@ -595,15 +589,15 @@ public void forwardOrExecuteADTask( ); } else { Map updatedFields = new HashMap<>(); - updatedFields.put(STATE_FIELD, TaskState.INIT.name()); - updatedFields.put(INIT_PROGRESS_FIELD, 0.0f); + updatedFields.put(TimeSeriesTask.STATE_FIELD, TaskState.INIT.name()); + updatedFields.put(TimeSeriesTask.INIT_PROGRESS_FIELD, 0.0f); ActionListener workerNodeResponseListener = workerNodeResponseListener( adTask, transportService, listener ); adTaskManager - .updateADTask( + .updateTask( adTask.getTaskId(), updatedFields, ActionListener.wrap(r -> forwardOrExecuteEntityTask(adTask, transportService, workerNodeResponseListener), e -> { @@ -634,7 +628,7 @@ private ActionListener workerNodeResponseListener( ) { ActionListener actionListener = ActionListener.wrap(r -> { listener.onResponse(r); - if (adTask.isEntityTask()) { + if (adTask.isHistoricalEntityTask()) { // When reach this line, the entity task already been put into worker node's cache. // Then it's safe to move entity from temp entities queue to running entities queue. adTaskCacheManager.moveToRunningEntity(adTask.getConfigId(), adTaskManager.convertEntityToString(adTask)); @@ -704,12 +698,12 @@ private synchronized void startNewEntityTaskLane(ADTask adTask, TransportService } private void dispatchTask(ADTask adTask, ActionListener listener) { - hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { - ADStatsRequest adStatsRequest = new ADStatsRequest(dataNodes); + hashRing.getNodesWithSameLocalVersion(dataNodes -> { + StatsRequest adStatsRequest = new StatsRequest(dataNodes); adStatsRequest.addAll(ImmutableSet.of(AD_EXECUTING_BATCH_TASK_COUNT.getName(), JVM_HEAP_USAGE.getName())); client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { - List candidateNodeResponse = adStatsResponse + List candidateNodeResponse = adStatsResponse .getNodes() .stream() .filter(stat -> (long) stat.getStatsMap().get(JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) @@ -739,9 +733,9 @@ private void dispatchTask(ADTask adTask, ActionListener listener) listener.onFailure(new LimitExceededException(adTask.getConfigId(), errorMessage)); return; } - Optional targetNode = candidateNodeResponse + Optional targetNode = candidateNodeResponse .stream() - .sorted((ADStatsNodeResponse r1, ADStatsNodeResponse r2) -> { + .sorted((StatsNodeResponse r1, StatsNodeResponse r2) -> { int result = ((Long) r1.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName())) .compareTo((Long) r2.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName())); if (result == 0) { @@ -808,11 +802,11 @@ private ActionListener internalBatchTaskListener(ADTask adTask, Transpor .cleanDetectorCache( adTask, transportService, - () -> adTaskManager.updateADTask(taskId, ImmutableMap.of(STATE_FIELD, TaskState.FINISHED.name())) + () -> adTaskManager.updateTask(taskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, 
TaskState.FINISHED.name())) ); } else { // Set entity task as FINISHED here - adTaskManager.updateADTask(adTask.getTaskId(), ImmutableMap.of(STATE_FIELD, TaskState.FINISHED.name())); + adTaskManager.updateTask(adTask.getTaskId(), ImmutableMap.of(TimeSeriesTask.STATE_FIELD, TaskState.FINISHED.name())); adTaskManager.entityTaskDone(adTask, null, transportService); } }, e -> { @@ -845,7 +839,7 @@ private void handleException(ADTask adTask, Exception e) { adStats.getStat(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName()).increment(); } // Handle AD task exception - adTaskManager.handleADTaskException(adTask, e); + adTaskManager.handleTaskException(adTask, e); } private void executeADBatchTaskOnWorkerNode(ADTask adTask, ActionListener internalListener) { @@ -888,19 +882,19 @@ private void checkCircuitBreaker(ADTask adTask) { private void runFirstPiece(ADTask adTask, Instant executeStartTime, ActionListener internalListener) { try { adTaskManager - .updateADTask( + .updateTask( adTask.getTaskId(), ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.INIT.name(), - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, adTask.getDetectionDateRange().getStartTime().toEpochMilli(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 0.0f, - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, 0.0f, - WORKER_NODE_FIELD, + TimeSeriesTask.WORKER_NODE_FIELD, clusterService.localNode().getId() ), ActionListener.wrap(r -> { @@ -996,7 +990,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons dataStartTime = dataStartTime - dataStartTime % interval; dataEndTime = dataEndTime - dataEndTime % interval; logger.debug("adjusted date range: start: {}, end: {}, taskId: {}", dataStartTime, dataEndTime, taskId); - if ((dataEndTime - dataStartTime) < NUM_MIN_SAMPLES * interval) { + if ((dataEndTime - dataStartTime) < TimeSeriesSettings.NUM_MIN_SAMPLES * interval) { internalListener.onFailure(new TimeSeriesException("There is not enough data to train model").countedInStats(false)); return; } @@ -1229,10 +1223,12 @@ private void storeAnomalyResultAndRunNextPiece( false ); + String detectorId = adTask.getConfigId(); anomalyResultBulkIndexHandler - .bulkIndexAnomalyResult( + .bulk( resultIndex, anomalyResults, + detectorId, runBefore == null ? actionListener : ActionListener.runBefore(actionListener, runBefore) ); } @@ -1252,7 +1248,7 @@ private void runNextPiece( String taskState = initProgress >= 1.0f ? 
TaskState.RUNNING.name() : TaskState.INIT.name(); logger.debug("Init progress: {}, taskState:{}, task id: {}", initProgress, taskState, taskId); - if (initProgress >= 1.0f && adTask.isEntityTask()) { + if (initProgress >= 1.0f && adTask.isHistoricalEntityTask()) { updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), TaskState.RUNNING.name()); } @@ -1273,17 +1269,17 @@ private void runNextPiece( float taskProgress = (float) (pieceStartTime - dataStartTime) / (dataEndTime - dataStartTime); logger.debug("Task progress: {}, task id:{}, detector id:{}", taskProgress, taskId, detectorId); adTaskManager - .updateADTask( + .updateTask( taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, taskState, - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, pieceStartTime, - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, taskProgress, - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress ), ActionListener @@ -1306,19 +1302,19 @@ private void runNextPiece( logger.info("AD task finished for detector {}, task id: {}", detectorId, taskId); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); adTaskManager - .updateADTask( + .updateTask( taskId, ImmutableMap .of( - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, dataEndTime, - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0f, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli(), - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress, - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.FINISHED ), ActionListener.wrap(r -> internalListener.onResponse("task execution done"), e -> internalListener.onFailure(e)) @@ -1328,7 +1324,7 @@ private void runNextPiece( private void updateDetectorLevelTaskState(String detectorId, String detectorTaskId, String newState) { ExecutorFunction function = () -> adTaskManager - .updateADTask(detectorTaskId, ImmutableMap.of(STATE_FIELD, newState), ActionListener.wrap(r -> { + .updateTask(detectorTaskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, newState), ActionListener.wrap(r -> { logger.info("Updated HC detector task: {} state as: {} for detector: {}", detectorTaskId, newState, detectorId); adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); }, e -> { logger.error("Failed to update HC detector task: {} for detector: {}", detectorTaskId, detectorId); })); @@ -1352,7 +1348,7 @@ private float calculateInitProgress(String taskId) { if (rcf == null) { return 0.0f; } - float initProgress = (float) rcf.getTotalUpdates() / NUM_MIN_SAMPLES; + float initProgress = (float) rcf.getTotalUpdates() / TimeSeriesSettings.NUM_MIN_SAMPLES; logger.debug("RCF total updates {} for task {}", rcf.getTotalUpdates(), taskId); return initProgress > 1.0f ? 1.0f : initProgress; } @@ -1381,7 +1377,7 @@ private void checkIfADTaskCancelledAndCleanupCache(ADTask adTask) { String cancelledBy = adTaskCacheManager.getCancelledBy(taskId); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId) - && isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { + && ParseUtils.isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { // Clean up historical task cache for HC detector on worker node if no running entity task. 
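
// The init-progress math from calculateInitProgress above, in isolation: progress is
// the ratio of RCF model updates to the minimum training sample count, clamped to 1.
class InitProgressSketch {
    static float initProgress(long totalUpdates, long numMinSamples) {
        float progress = (float) totalUpdates / numMinSamples;
        return Math.min(progress, 1.0f);
    }

    public static void main(String[] args) {
        System.out.println(initProgress(16, 32)); // 0.5 -> still initializing
        System.out.println(initProgress(64, 32)); // 1.0 -> capped; model fully initialized
    }
}
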
logger.info("All AD task cancelled, cleanup historical task cache for detector {}", detectorId); adTaskCacheManager.removeHistoricalTaskCache(detectorId); diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index 014a9f798..aab2eb652 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -142,7 +142,7 @@ public synchronized void add(ADTask adTask) { throw new DuplicateTaskException(DETECTOR_IS_RUNNING); } // It's possible that multiple entity tasks of one detector run on same data node. - if (!adTask.isEntityTask() && containsTaskOfDetector(detectorId)) { + if (!adTask.isHistoricalEntityTask() && containsTaskOfDetector(detectorId)) { throw new DuplicateTaskException(DETECTOR_IS_RUNNING); } checkRunningTaskLimit(); @@ -154,7 +154,7 @@ public synchronized void add(ADTask adTask) { ADBatchTaskCache taskCache = new ADBatchTaskCache(adTask); taskCache.getCacheMemorySize().set(neededCacheSize); batchTaskCaches.put(taskId, taskCache); - if (adTask.isEntityTask()) { + if (adTask.isHistoricalEntityTask()) { ADHCBatchTaskRunState hcBatchTaskRunState = getHCBatchTaskRunState(detectorId, adTask.getConfigLevelTaskId()); if (hcBatchTaskRunState != null) { hcBatchTaskRunState.setLastTaskRunTimeInMillis(Instant.now().toEpochMilli()); diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index 454e5cfc8..644117cfd 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -12,27 +12,13 @@ package org.opensearch.ad.task; import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; import static org.opensearch.ad.constant.ADCommonMessages.EXCEED_HISTORICAL_ANALYSIS_LIMIT; import static org.opensearch.ad.constant.ADCommonMessages.HC_DETECTOR_TASK_IS_UPDATING; import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; -import static org.opensearch.ad.model.ADTask.COORDINATING_NODE_FIELD; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.ERROR_FIELD; -import static org.opensearch.ad.model.ADTask.ESTIMATED_MINUTES_LEFT_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.LAST_UPDATE_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; import static 
org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; @@ -40,30 +26,21 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; -import static org.opensearch.ad.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT; -import static org.opensearch.ad.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; -import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD; import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; import static org.opensearch.timeseries.model.TaskType.taskTypeToString; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; -import static org.opensearch.timeseries.util.ExceptionUtil.getErrorMessage; -import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure; -import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; -import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT; +import static org.opensearch.timeseries.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT; import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Locale; @@ -80,46 +57,28 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.join.ScoreMode; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.Version; import org.opensearch.action.ActionListenerResponseHandler; -import org.opensearch.action.bulk.BulkAction; -import org.opensearch.action.bulk.BulkItemResponse; -import org.opensearch.action.bulk.BulkRequest; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.ad.cluster.HashRing; +import 
org.opensearch.ad.ADTaskProfileRunner; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.ADEntityTaskProfile; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; import org.opensearch.ad.transport.ADBatchAnomalyResultAction; import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; import org.opensearch.ad.transport.ADCancelTaskAction; import org.opensearch.ad.transport.ADCancelTaskRequest; -import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesAction; -import org.opensearch.ad.transport.ADStatsRequest; -import org.opensearch.ad.transport.ADTaskProfileAction; -import org.opensearch.ad.transport.ADTaskProfileNodeResponse; -import org.opensearch.ad.transport.ADTaskProfileRequest; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.ad.transport.ForwardADTaskAction; import org.opensearch.ad.transport.ForwardADTaskRequest; import org.opensearch.client.Client; @@ -127,7 +86,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -141,34 +99,39 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.NestedQueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; -import org.opensearch.index.reindex.DeleteByQueryAction; -import org.opensearch.index.reindex.DeleteByQueryRequest; -import org.opensearch.index.reindex.UpdateByQueryAction; -import org.opensearch.index.reindex.UpdateByQueryRequest; -import org.opensearch.script.Script; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.SortOrder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.DuplicateTaskException; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TaskCancelledException; import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.function.BiCheckedFunction; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.function.ResponseTransformer; +import org.opensearch.timeseries.model.Config; import 
org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.EntityTaskProfile; import org.opensearch.timeseries.model.TaskState; -import org.opensearch.timeseries.task.RealtimeTaskCache; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; @@ -179,28 +142,21 @@ /** * Manage AD task. */ -public class ADTaskManager { +public class ADTaskManager extends TaskManager { public static final String AD_TASK_LEAD_NODE_MODEL_ID = "ad_task_lead_node_model_id"; public static final String AD_TASK_MAINTAINENCE_NODE_MODEL_ID = "ad_task_maintainence_node_model_id"; // HC batch task timeout after 10 minutes if no update after last known run time. public static final int HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS = 600_000; - private final Logger logger = LogManager.getLogger(this.getClass()); + public final Logger logger = LogManager.getLogger(this.getClass()); static final String STATE_INDEX_NOT_EXIST_MSG = "State index does not exist."; private final Set retryableErrors = ImmutableSet.of(EXCEED_HISTORICAL_ANALYSIS_LIMIT, NO_ELIGIBLE_NODE_TO_RUN_DETECTOR); - private final Client client; - private final ClusterService clusterService; - private final NamedXContentRegistry xContentRegistry; - private final ADIndexManagement detectionIndices; + private final DiscoveryNodeFilterer nodeFilter; - private final ADTaskCacheManager adTaskCacheManager; private final HashRing hashRing; - private volatile Integer maxOldAdTaskDocsPerDetector; private volatile Integer pieceIntervalSeconds; - private volatile boolean deleteADResultWhenDeleteDetector; + private volatile TransportRequestOptions transportRequestOptions; - private final ThreadPool threadPool; - private static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5; private final Semaphore checkingTaskSlot; private volatile Integer maxAdBatchTaskPerNode; @@ -208,6 +164,7 @@ public class ADTaskManager { private final Semaphore scaleEntityTaskLane; private static final int SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS = 10_000; // 10 seconds + private final ADTaskProfileRunner taskProfileRunner; public ADTaskManager( Settings settings, @@ -218,29 +175,38 @@ public ADTaskManager( DiscoveryNodeFilterer nodeFilter, HashRing hashRing, ADTaskCacheManager adTaskCacheManager, - ThreadPool threadPool + ThreadPool threadPool, + NodeStateManager nodeStateManager, + ADTaskProfileRunner taskProfileRunner ) { - this.client = client; - this.xContentRegistry = xContentRegistry; - this.detectionIndices = detectionIndices; + super( + adTaskCacheManager, + clusterService, + client, + DETECTION_STATE_INDEX, + ADTaskType.REALTIME_TASK_TYPES, + ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES, + Collections.emptyList(), + detectionIndices, + nodeStateManager, + AnalysisType.AD, + xContentRegistry, + DETECTOR_ID_FIELD, + MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + settings, + threadPool, + 
ALL_AD_RESULTS_INDEX_PATTERN, + AD_BATCH_TASK_THREAD_POOL_NAME, + DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, + TaskState.STOPPED + ); + this.nodeFilter = nodeFilter; - this.clusterService = clusterService; - this.adTaskCacheManager = adTaskCacheManager; this.hashRing = hashRing; - this.maxOldAdTaskDocsPerDetector = MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, it -> maxOldAdTaskDocsPerDetector = it); - this.pieceIntervalSeconds = BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it); - this.deleteADResultWhenDeleteDetector = DELETE_AD_RESULT_WHEN_DELETE_DETECTOR.get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, it -> deleteADResultWhenDeleteDetector = it); - this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); @@ -257,83 +223,10 @@ public ADTaskManager( clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> { transportRequestOptions = TransportRequestOptions.builder().withType(TransportRequestOptions.Type.REG).withTimeout(it).build(); }); - this.threadPool = threadPool; + this.checkingTaskSlot = new Semaphore(1); this.scaleEntityTaskLane = new Semaphore(1); - } - - /** - * Start detector. Will create schedule job for realtime detector, - * and start AD task for historical detector. - * - * @param detectorId detector id - * @param detectionDateRange historical analysis date range - * @param handler anomaly detector job action handler - * @param user user - * @param transportService transport service - * @param context thread context - * @param listener action listener - */ - public void startDetector( - String detectorId, - DateRange detectionDateRange, - IndexAnomalyDetectorJobActionHandler handler, - User user, - TransportService transportService, - ThreadContext.StoredContext context, - ActionListener listener - ) { - // upgrade index mapping of AD default indices - detectionIndices.update(); - - getDetector(detectorId, (detector) -> { - if (!detector.isPresent()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - - // Validate if detector is ready to start. Will return null if ready to start. 
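
// A sketch of the validate-then-start convention noted above: the validator returns an
// error message, or null when the detector is ready. Types and the sample check are
// illustrative, not the plugin's actual validation.
import java.util.function.Consumer;

class ValidateThenStartSketch {
    static String validate(boolean hasEnabledFeatures) {
        // real validation inspects features, categories, intervals, ...
        return hasEnabledFeatures ? null : "detector has no enabled features";
    }

    static void start(boolean hasEnabledFeatures, Runnable onStart, Consumer<String> onError) {
        String errorMessage = validate(hasEnabledFeatures);
        if (errorMessage != null) {
            onError.accept(errorMessage); // maps to listener.onFailure(...) with BAD_REQUEST above
            return;
        }
        onStart.run();
    }

    public static void main(String[] args) {
        start(false, () -> System.out.println("started"), System.err::println);
        start(true, () -> System.out.println("started"), System.err::println);
    }
}
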
- String errorMessage = validateDetector(detector.get()); - if (errorMessage != null) { - listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); - return; - } - String resultIndex = detector.get().getCustomResultIndex(); - if (resultIndex == null) { - startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector); - return; - } - context.restore(); - detectionIndices - .initCustomResultIndexAndExecute( - resultIndex, - () -> startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector), - listener - ); - - }, listener); - } - - private void startRealtimeOrHistoricalDetection( - DateRange detectionDateRange, - IndexAnomalyDetectorJobActionHandler handler, - User user, - TransportService transportService, - ActionListener listener, - Optional detector - ) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (detectionDateRange == null) { - // start realtime job - handler.startAnomalyDetectorJob(detector.get(), listener); - } else { - // start historical analysis task - forwardApplyForTaskSlotsRequestToLeadNode(detector.get(), detectionDateRange, user, transportService, listener); - } - } catch (Exception e) { - logger.error("Failed to stash context", e); - listener.onFailure(e); - } + this.taskProfileRunner = taskProfileRunner; } /** @@ -344,21 +237,22 @@ private void startRealtimeOrHistoricalDetection( * 3. Then coordinating node will choose one data node with least load as work * node and dispatch historical analysis to it. * - * @param detector detector + * @param config config accessor * @param detectionDateRange detection date range * @param user user * @param transportService transport service * @param listener action listener */ - protected void forwardApplyForTaskSlotsRequestToLeadNode( - AnomalyDetector detector, + @Override + public void startHistorical( + Config config, DateRange detectionDateRange, User user, TransportService transportService, - ActionListener listener + ActionListener listener ) { ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest( - detector, + (AnomalyDetector) config, detectionDateRange, user, ADTaskAction.APPLY_FOR_TASK_SLOTS @@ -369,7 +263,7 @@ protected void forwardApplyForTaskSlotsRequestToLeadNode( public void forwardScaleTaskSlotRequestToLeadNode( ADTask adTask, TransportService transportService, - ActionListener listener + ActionListener listener ) { forwardRequestToLeadNode(new ForwardADTaskRequest(adTask, ADTaskAction.CHECK_AVAILABLE_TASK_SLOTS), transportService, listener); } @@ -377,9 +271,9 @@ public void forwardScaleTaskSlotRequestToLeadNode( public void forwardRequestToLeadNode( ForwardADTaskRequest forwardADTaskRequest, TransportService transportService, - ActionListener listener + ActionListener listener ) { - hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> { + hashRing.buildAndGetOwningNodeWithSameLocalVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> { if (!node.isPresent()) { listener.onFailure(new ResourceNotFoundException("Can't find AD task lead node")); return; @@ -390,7 +284,7 @@ public void forwardRequestToLeadNode( ForwardADTaskAction.NAME, forwardADTaskRequest, transportRequestOptions, - new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + new ActionListenerResponseHandler<>(listener, JobResponse::new) ); }, listener); } @@ -411,10 +305,10 @@ public void 
startHistoricalAnalysis( User user, int availableTaskSlots, TransportService transportService, - ActionListener listener + ActionListener listener ) { String detectorId = detector.getId(); - hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(detectorId, owningNode -> { + hashRing.buildAndGetOwningNodeWithSameLocalVersion(detectorId, owningNode -> { if (!owningNode.isPresent()) { logger.debug("Can't find eligible node to run as AD task's coordinating node"); listener.onFailure(new OpenSearchStatusException("No eligible node to run detector", RestStatus.INTERNAL_SERVER_ERROR)); @@ -465,9 +359,9 @@ protected void forwardDetectRequestToCoordinatingNode( ADTaskAction adTaskAction, TransportService transportService, DiscoveryNode node, - ActionListener listener + ActionListener listener ) { - Version adVersion = hashRing.getAdVersion(node.getId()); + Version adVersion = hashRing.getVersion(node.getId()); transportService .sendRequest( node, @@ -476,7 +370,7 @@ protected void forwardDetectRequestToCoordinatingNode( // node, check ADTaskManager#cleanDetectorCache. new ForwardADTaskRequest(detector, detectionDateRange, user, adTaskAction, availableTaskSlots, adVersion), transportRequestOptions, - new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + new ActionListenerResponseHandler<>(listener, JobResponse::new) ); } @@ -492,7 +386,7 @@ protected void forwardADTaskToCoordinatingNode( ADTask adTask, ADTaskAction adTaskAction, TransportService transportService, - ActionListener listener + ActionListener listener ) { logger.debug("Forward AD task to coordinating node, task id: {}, action: {}", adTask.getTaskId(), adTaskAction.name()); transportService @@ -501,7 +395,7 @@ protected void forwardADTaskToCoordinatingNode( ForwardADTaskAction.NAME, new ForwardADTaskRequest(adTask, adTaskAction), transportRequestOptions, - new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + new ActionListenerResponseHandler<>(listener, JobResponse::new) ); } @@ -519,7 +413,7 @@ protected void forwardStaleRunningEntitiesToCoordinatingNode( ADTaskAction adTaskAction, TransportService transportService, List staleRunningEntity, - ActionListener listener + ActionListener listener ) { transportService .sendRequest( @@ -527,7 +421,7 @@ protected void forwardStaleRunningEntitiesToCoordinatingNode( ForwardADTaskAction.NAME, new ForwardADTaskRequest(adTask, adTaskAction, staleRunningEntity), transportRequestOptions, - new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + new ActionListenerResponseHandler<>(listener, JobResponse::new) ); } @@ -551,7 +445,7 @@ public void checkTaskSlots( User user, ADTaskAction afterCheckAction, TransportService transportService, - ActionListener listener + ActionListener listener ) { String detectorId = detector.getId(); logger.debug("Start checking task slots for detector: {}, task action: {}", detectorId, afterCheckAction); @@ -566,19 +460,19 @@ public void checkTaskSlots( ); return; } - ActionListener wrappedActionListener = ActionListener.runAfter(listener, () -> { + ActionListener wrappedActionListener = ActionListener.runAfter(listener, () -> { checkingTaskSlot.release(1); logger.debug("Release checking task slot semaphore on lead node for detector {}", detectorId); }); - hashRing.getNodesWithSameLocalAdVersion(nodes -> { + hashRing.getNodesWithSameLocalVersion(nodes -> { int maxAdTaskSlots = nodes.length * maxAdBatchTaskPerNode; - ADStatsRequest adStatsRequest = new ADStatsRequest(nodes); + StatsRequest 
adStatsRequest = new StatsRequest(nodes); adStatsRequest .addAll(ImmutableSet.of(AD_USED_BATCH_TASK_SLOT_COUNT.getName(), AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName())); client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { int totalUsedTaskSlots = 0; // Total entity tasks running on worker nodes int totalAssignedTaskSlots = 0; // Total assigned task slots on coordinating nodes - for (ADStatsNodeResponse response : adStatsResponse.getNodes()) { + for (StatsNodeResponse response : adStatsResponse.getNodes()) { totalUsedTaskSlots += (int) response.getStatsMap().get(AD_USED_BATCH_TASK_SLOT_COUNT.getName()); totalAssignedTaskSlots += (int) response.getStatsMap().get(AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName()); } @@ -648,7 +542,7 @@ private void forwardToCoordinatingNode( User user, ADTaskAction targetActionOfTaskSlotChecking, TransportService transportService, - ActionListener wrappedActionListener, + ActionListener wrappedActionListener, int approvedTaskSlots ) { switch (targetActionOfTaskSlotChecking) { @@ -675,7 +569,7 @@ protected void scaleTaskLaneOnCoordinatingNode( ADTask adTask, int approvedTaskSlot, TransportService transportService, - ActionListener listener + ActionListener listener ) { DiscoveryNode coordinatingNode = getCoordinatingNode(adTask); transportService @@ -684,7 +578,7 @@ protected void scaleTaskLaneOnCoordinatingNode( ForwardADTaskAction.NAME, new ForwardADTaskRequest(adTask, approvedTaskSlot, ADTaskAction.SCALE_ENTITY_TASK_SLOTS), transportRequestOptions, - new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + new ActionListenerResponseHandler<>(listener, JobResponse::new) ); } @@ -704,418 +598,32 @@ private DiscoveryNode getCoordinatingNode(ADTask adTask) { return targetNode; } - /** - * Start anomaly detector. - * For historical analysis, this method will be called on coordinating node. - * For realtime task, we won't know AD job coordinating node until AD job starts. So - * this method will be called on vanilla node. - * - * Will init task index if not exist and write new AD task to index. If task index - * exists, will check if there is task running. If no running task, reset old task - * as not latest and clean old tasks which exceeds max old task doc limitation. - * Then find out node with least load and dispatch task to that node(worker node). - * - * @param detector anomaly detector - * @param detectionDateRange detection date range - * @param user user - * @param transportService transport service - * @param listener action listener - */ - public void startDetector( - AnomalyDetector detector, - DateRange detectionDateRange, - User user, - TransportService transportService, - ActionListener listener - ) { - try { - if (detectionIndices.doesStateIndexExist()) { - // If detection index exist, check if latest AD task is running - getAndExecuteOnLatestDetectorLevelTask(detector.getId(), getADTaskTypes(detectionDateRange), (adTask) -> { - if (!adTask.isPresent() || adTask.get().isDone()) { - updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); - } else { - listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); - } - }, transportService, true, listener); - } else { - // If detection index doesn't exist, create index and execute detector. 
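
// A sketch of the create-index-then-run flow in the deleted startDetector above:
// acknowledged creation and an "already exists" race both proceed; anything else fails
// the caller. Callback types here are simplified stand-ins for ActionListener.
import java.util.function.Consumer;

class InitStateIndexSketch {
    static class AlreadyExistsException extends RuntimeException {}

    interface StateIndexClient {
        // calls onAck(true/false) on completion, or onError(e); stand-in for initStateIndex(listener)
        void createStateIndex(Consumer<Boolean> onAck, Consumer<Exception> onError);
    }

    static void initThenRun(StateIndexClient client, Runnable runTask, Consumer<Exception> fail) {
        client.createStateIndex(acknowledged -> {
            if (acknowledged) {
                runTask.run();
            } else {
                fail.accept(new IllegalStateException("create index not acknowledged"));
            }
        }, e -> {
            if (e instanceof AlreadyExistsException) {
                runTask.run(); // a concurrent caller created it first; safe to continue
            } else {
                fail.accept(e);
            }
        });
    }
}
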
- detectionIndices.initStateIndex(ActionListener.wrap(r -> { - if (r.isAcknowledged()) { - logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); - updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); - } else { - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); - logger.warn(error); - listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, e -> { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { - updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); - } else { - logger.error("Failed to init anomaly detection state index", e); - listener.onFailure(e); - } - })); - } - } catch (Exception e) { - logger.error("Failed to start detector " + detector.getId(), e); - listener.onFailure(e); - } - } - - private ADTaskType getADTaskType(AnomalyDetector detector, DateRange detectionDateRange) { + @Override + protected TaskType getTaskType(Config config, DateRange detectionDateRange, boolean runOnce) { if (detectionDateRange == null) { - return detector.isHighCardinality() ? ADTaskType.REALTIME_HC_DETECTOR : ADTaskType.REALTIME_SINGLE_ENTITY; + return config.isHighCardinality() ? ADTaskType.REALTIME_HC_DETECTOR : ADTaskType.REALTIME_SINGLE_ENTITY; } else { - return detector.isHighCardinality() ? ADTaskType.HISTORICAL_HC_DETECTOR : ADTaskType.HISTORICAL_SINGLE_ENTITY; - } - } - - private List getADTaskTypes(DateRange detectionDateRange) { - return getADTaskTypes(detectionDateRange, false); - } - - /** - * Get list of task types. - * 1. If detection date range is null, will return all realtime task types - * 2. If detection date range is not null, will return all historical detector level tasks types - * if resetLatestTaskStateFlag is true; otherwise return all historical tasks types include - * HC entity level task type. - * @param detectionDateRange detection date range - * @param resetLatestTaskStateFlag reset latest task state or not - * @return list of AD task types - */ - private List getADTaskTypes(DateRange detectionDateRange, boolean resetLatestTaskStateFlag) { - if (detectionDateRange == null) { - return REALTIME_TASK_TYPES; - } else { - if (resetLatestTaskStateFlag) { - // return all task types include HC entity task to make sure we can reset all tasks latest flag - return ALL_HISTORICAL_TASK_TYPES; - } else { - return HISTORICAL_DETECTOR_TASK_TYPES; - } - } - } - - /** - * Stop detector. - * For realtime detector, will set detector job as disabled. - * For historical detector, will set its AD task as cancelled. 
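
// The task-type selection in getTaskType above, reduced to its decision table: a null
// date range means realtime, otherwise historical; high-cardinality configs get the HC
// variant. Enum names mirror ADTaskType.
class TaskTypeSketch {
    enum TaskType { REALTIME_SINGLE_ENTITY, REALTIME_HC_DETECTOR, HISTORICAL_SINGLE_ENTITY, HISTORICAL_HC_DETECTOR }

    static TaskType taskType(boolean highCardinality, boolean hasDateRange) {
        if (!hasDateRange) {
            return highCardinality ? TaskType.REALTIME_HC_DETECTOR : TaskType.REALTIME_SINGLE_ENTITY;
        }
        return highCardinality ? TaskType.HISTORICAL_HC_DETECTOR : TaskType.HISTORICAL_SINGLE_ENTITY;
    }

    public static void main(String[] args) {
        System.out.println(taskType(true, false));  // REALTIME_HC_DETECTOR
        System.out.println(taskType(false, true));  // HISTORICAL_SINGLE_ENTITY
    }
}
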
- * - * @param detectorId detector id - * @param historical stop historical analysis or not - * @param handler AD job action handler - * @param user user - * @param transportService transport service - * @param listener action listener - */ - public void stopDetector( - String detectorId, - boolean historical, - IndexAnomalyDetectorJobActionHandler handler, - User user, - TransportService transportService, - ActionListener listener - ) { - getDetector(detectorId, (detector) -> { - if (!detector.isPresent()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - if (historical) { - // stop historical analyis - getAndExecuteOnLatestDetectorLevelTask( - detectorId, - HISTORICAL_DETECTOR_TASK_TYPES, - (task) -> stopHistoricalAnalysis(detectorId, task, user, listener), - transportService, - false,// don't need to reset task state when stop detector - listener - ); - } else { - // stop realtime detector job - handler.stopAnomalyDetectorJob(detectorId, listener); - } - }, listener); - } - - /** - * Get anomaly detector and execute consumer function. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param function consumer function - * @param listener action listener - * @param action listener response type - */ - public void getDetector(String detectorId, Consumer> function, ActionListener listener) { - GetRequest getRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); - client.get(getRequest, ActionListener.wrap(response -> { - if (!response.isExists()) { - function.accept(Optional.empty()); - return; - } - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); - - function.accept(Optional.of(detector)); - } catch (Exception e) { - String message = "Failed to parse anomaly detector " + detectorId; - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, exception -> { - logger.error("Failed to get detector " + detectorId, exception); - listener.onFailure(exception); - })); - } - - /** - * Get latest AD task and execute consumer function. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param adTaskTypes AD task types - * @param function consumer function - * @param transportService transport service - * @param resetTaskState reset task state or not - * @param listener action listener - * @param action listener response type - */ - public void getAndExecuteOnLatestDetectorLevelTask( - String detectorId, - List adTaskTypes, - Consumer> function, - TransportService transportService, - boolean resetTaskState, - ActionListener listener - ) { - getAndExecuteOnLatestADTask(detectorId, null, null, adTaskTypes, function, transportService, resetTaskState, listener); - } - - /** - * Get one latest AD task and execute consumer function. - * [Important!] 
Make sure listener returns in function - * - * @param detectorId detector id - * @param parentTaskId parent task id - * @param entity entity value - * @param adTaskTypes AD task types - * @param function consumer function - * @param transportService transport service - * @param resetTaskState reset task state or not - * @param listener action listener - * @param action listener response type - */ - public void getAndExecuteOnLatestADTask( - String detectorId, - String parentTaskId, - Entity entity, - List adTaskTypes, - Consumer> function, - TransportService transportService, - boolean resetTaskState, - ActionListener listener - ) { - getAndExecuteOnLatestADTasks(detectorId, parentTaskId, entity, adTaskTypes, (taskList) -> { - if (taskList != null && taskList.size() > 0) { - function.accept(Optional.ofNullable(taskList.get(0))); - } else { - function.accept(Optional.empty()); - } - }, transportService, resetTaskState, 1, listener); - } - - /** - * Get latest AD tasks and execute consumer function. - * If resetTaskState is true, will collect latest task's profile data from all data nodes. If no data - * node running the latest task, will reset the task state as STOPPED; otherwise, check if there is - * any stale running entities(entity exists in coordinating node cache but no task running on worker - * node) and clean up. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param parentTaskId parent task id - * @param entity entity value - * @param adTaskTypes AD task types - * @param function consumer function - * @param transportService transport service - * @param resetTaskState reset task state or not - * @param size return how many AD tasks - * @param listener action listener - * @param response type of action listener - */ - public void getAndExecuteOnLatestADTasks( - String detectorId, - String parentTaskId, - Entity entity, - List adTaskTypes, - Consumer> function, - TransportService transportService, - boolean resetTaskState, - int size, - ActionListener listener - ) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - if (parentTaskId != null) { - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + return config.isHighCardinality() ? 
ADTaskType.HISTORICAL_HC_DETECTOR : ADTaskType.HISTORICAL_SINGLE_ENTITY; } - if (adTaskTypes != null && adTaskTypes.size() > 0) { - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(adTaskTypes))); - } - if (entity != null && !isNullOrEmpty(entity.getAttributes())) { - String path = "entity"; - String entityKeyFieldName = path + ".name"; - String entityValueFieldName = path + ".value"; - - for (Map.Entry attribute : entity.getAttributes().entrySet()) { - BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); - TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); - TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); - - entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); - NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); - query.filter(nestedQueryBuilder); - } - } - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.source(sourceBuilder); - searchRequest.indices(DETECTION_STATE_INDEX); - - client.search(searchRequest, ActionListener.wrap(r -> { - // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 - // getTotalHits will be null when we track_total_hits is false in the query request. - // Add more checking here to cover some unknown cases. - List adTasks = new ArrayList<>(); - if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { - // don't throw exception here as consumer functions need to handle missing task - // in different way. - function.accept(adTasks); - return; - } - - Iterator iterator = r.getHits().iterator(); - while (iterator.hasNext()) { - SearchHit searchHit = iterator.next(); - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - ADTask adTask = ADTask.parse(parser, searchHit.getId()); - adTasks.add(adTask); - } catch (Exception e) { - String message = "Failed to parse AD task for detector " + detectorId + ", task id " + searchHit.getId(); - logger.error(message, e); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - } - if (resetTaskState) { - resetLatestDetectorTaskState(adTasks, function, transportService, listener); - } else { - function.accept(adTasks); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - function.accept(new ArrayList<>()); - } else { - logger.error("Failed to search AD task for detector " + detectorId, e); - listener.onFailure(e); - } - })); } /** - * Reset latest detector task state. Will reset both historical and realtime tasks. - * [Important!] 
Make sure listener returns in function - * - * @param adTasks ad tasks - * @param function consumer function - * @param transportService transport service - * @param listener action listener - * @param response type of action listener - */ - private void resetLatestDetectorTaskState( - List adTasks, - Consumer> function, - TransportService transportService, - ActionListener listener - ) { - List runningHistoricalTasks = new ArrayList<>(); - List runningRealtimeTasks = new ArrayList<>(); - for (ADTask adTask : adTasks) { - if (!adTask.isEntityTask() && !adTask.isDone()) { - if (!adTask.isHistoricalTask()) { - // try to reset task state if realtime task is not ended - runningRealtimeTasks.add(adTask); - } else { - // try to reset task state if the historical task has not been updated for 2 piece intervals - runningHistoricalTasks.add(adTask); - } - } - } - - resetHistoricalDetectorTaskState( - runningHistoricalTasks, - () -> resetRealtimeDetectorTaskState(runningRealtimeTasks, () -> function.accept(adTasks), transportService, listener), - transportService, - listener - ); - } - - private void resetRealtimeDetectorTaskState( - List runningRealtimeTasks, - ExecutorFunction function, - TransportService transportService, - ActionListener listener - ) { - if (isNullOrEmpty(runningRealtimeTasks)) { - function.execute(); - return; - } - ADTask adTask = runningRealtimeTasks.get(0); - String detectorId = adTask.getConfigId(); - GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - client.get(getJobRequest, ActionListener.wrap(r -> { - if (r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job job = Job.parse(parser); - if (!job.isEnabled()) { - logger.debug("AD job is disabled, reset realtime task as stopped for detector {}", detectorId); - resetTaskStateAsStopped(adTask, function, transportService, listener); - } else { - function.execute(); - } - } catch (IOException e) { - logger.error("Failed to parse AD job " + detectorId, e); - listener.onFailure(e); - } - } else { - logger.debug("AD job is not found, reset realtime task as stopped for detector {}", detectorId); - resetTaskStateAsStopped(adTask, function, transportService, listener); - } - }, e -> { - logger.error("Failed to get AD realtime job for detector " + detectorId, e); - listener.onFailure(e); - })); - } - - private void resetHistoricalDetectorTaskState( - List runningHistoricalTasks, + * If resetTaskState is true, this collects the latest task's profile data from all data nodes. If no data + * node is running the latest task, the task state is reset to STOPPED; otherwise, any stale running entities + * (an entity that exists in the coordinating node cache but has no task running on a worker node) + * are cleaned up. + */ + protected void resetHistoricalConfigTaskState( + List runningHistoricalTasks, ExecutorFunction function, TransportService transportService, ActionListener listener ) { - if (isNullOrEmpty(runningHistoricalTasks)) { + if (ParseUtils.isNullOrEmpty(runningHistoricalTasks)) { function.execute(); return; } - ADTask adTask = runningHistoricalTasks.get(0); + ADTask adTask = (ADTask) runningHistoricalTasks.get(0); // If the AD task is still running but its last updated time has not been refreshed for 2 piece intervals, we will get // the task profile to check if it's really running. If the task is not running, reset its state to STOPPED.
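The "2 piece intervals" staleness check referenced here is implemented by lastUpdateTimeOfHistoricalTaskExpired further down in this diff. A self-contained sketch of that expiry arithmetic; the class name is illustrative, the constants mirror the diff:

import java.time.Instant;
import java.time.temporal.ChronoUnit;

final class StaleTaskCheckSketch {
    // Wait at least 10 seconds: piece interval seconds is a dynamic setting,
    // so a user could shrink it below a safe threshold.
    static boolean lastUpdateExpired(Instant lastUpdateTime, int pieceIntervalSeconds) {
        int waitingTimeSeconds = Math.max(2 * pieceIntervalSeconds, 10);
        return lastUpdateTime.plus(waitingTimeSeconds, ChronoUnit.SECONDS).isBefore(Instant.now());
    }
}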
// For example, ES process crashes, then all tasks running on it will stay as running. We can reset the task @@ -1126,13 +634,13 @@ private void resetHistoricalDetectorTaskState( } String taskId = adTask.getTaskId(); AnomalyDetector detector = adTask.getDetector(); - getADTaskProfile(adTask, ActionListener.wrap(taskProfile -> { + taskProfileRunner.getTaskProfile(adTask, ActionListener.wrap(taskProfile -> { boolean taskStopped = isTaskStopped(taskId, detector, taskProfile); if (taskStopped) { logger.debug("Reset task state as stopped, task id: {}", adTask.getTaskId()); if (taskProfile.getTaskId() == null // This means coordinating node doesn't have HC detector cache && detector.isHighCardinality() - && !isNullOrEmpty(taskProfile.getEntityTaskProfiles())) { + && !ParseUtils.isNullOrEmpty(taskProfile.getEntityTaskProfiles())) { // If coordinating node restarted, HC detector cache on it will be gone. But worker node still // runs entity tasks, we'd better stop these entity tasks to clean up resource earlier. stopHistoricalAnalysis(adTask.getConfigId(), Optional.of(adTask), null, ActionListener.wrap(r -> { @@ -1151,7 +659,8 @@ private void resetHistoricalDetectorTaskState( if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { // Check if any running entity not run on worker node. If yes, we need to remove it // and poll next entity from pending entity queue and run it. - if (!isNullOrEmpty(taskProfile.getRunningEntities()) && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { + if (!ParseUtils.isNullOrEmpty(taskProfile.getRunningEntities()) + && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { List runningTasksInCoordinatingNodeCache = new ArrayList<>(taskProfile.getRunningEntities()); List runningTasksOnWorkerNode = new ArrayList<>(); if (taskProfile.getEntityTaskProfiles() != null && taskProfile.getEntityTaskProfiles().size() > 0) { @@ -1196,8 +705,8 @@ private boolean isTaskStopped(String taskId, AnomalyDetector detector, ADTaskPro } if (detector.isHighCardinality() && taskProfile.getTotalEntitiesInited() - && isNullOrEmpty(taskProfile.getRunningEntities()) - && isNullOrEmpty(taskProfile.getEntityTaskProfiles()) + && ParseUtils.isNullOrEmpty(taskProfile.getRunningEntities()) + && ParseUtils.isNullOrEmpty(taskProfile.getEntityTaskProfiles()) && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { logger.debug("AD task not running for HC detector {}, task {}", detectorId, taskId); return true; @@ -1212,12 +721,7 @@ public boolean hcBatchTaskExpired(Long latestHCTaskRunTime) { return latestHCTaskRunTime + HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS < Instant.now().toEpochMilli(); } - private void stopHistoricalAnalysis( - String detectorId, - Optional adTask, - User user, - ActionListener listener - ) { + public void stopHistoricalAnalysis(String detectorId, Optional adTask, User user, ActionListener listener) { if (!adTask.isPresent()) { listener.onFailure(new ResourceNotFoundException(detectorId, "Detector not started")); return; @@ -1229,70 +733,27 @@ private void stopHistoricalAnalysis( } String taskId = adTask.get().getTaskId(); - DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalAdVersion(); + DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalVersion(); String userName = user == null ? 
null : user.getName(); ADCancelTaskRequest cancelTaskRequest = new ADCancelTaskRequest(detectorId, taskId, userName, dataNodes); client.execute(ADCancelTaskAction.INSTANCE, cancelTaskRequest, ActionListener.wrap(response -> { - listener.onResponse(new AnomalyDetectorJobResponse(taskId, 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(taskId)); }, e -> { logger.error("Failed to cancel AD task " + taskId + ", detector id: " + detectorId, e); listener.onFailure(e); })); } - private boolean lastUpdateTimeOfHistoricalTaskExpired(ADTask adTask) { + private boolean lastUpdateTimeOfHistoricalTaskExpired(TimeSeriesTask adTask) { // Wait at least 10 seconds. Piece interval seconds is dynamic setting, user could change it to a smaller value. int waitingTime = Math.max(2 * pieceIntervalSeconds, 10); return adTask.getLastUpdateTime().plus(waitingTime, ChronoUnit.SECONDS).isBefore(Instant.now()); } - private void resetTaskStateAsStopped( - ADTask adTask, - ExecutorFunction function, - TransportService transportService, - ActionListener listener - ) { - cleanDetectorCache(adTask, transportService, () -> { - String taskId = adTask.getTaskId(); - Map updatedFields = ImmutableMap.of(STATE_FIELD, TaskState.STOPPED.name()); - updateADTask(taskId, updatedFields, ActionListener.wrap(r -> { - adTask.setState(TaskState.STOPPED.name()); - if (function != null) { - function.execute(); - } - // For realtime anomaly detection, we only create detector level task, no entity level realtime task. - if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { - // Reset running entity tasks as STOPPED - resetEntityTasksAsStopped(taskId); - } - }, e -> { - logger.error("Failed to update task state as STOPPED for task " + taskId, e); - listener.onFailure(e); - })); - }, listener); - } - - private void resetEntityTasksAsStopped(String detectorTaskId) { - UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); - updateByQueryRequest.indices(DETECTION_STATE_INDEX); - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); - query.filter(new TermQueryBuilder(TASK_TYPE_FIELD, ADTaskType.HISTORICAL_HC_ENTITY.name())); - query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); - updateByQueryRequest.setQuery(query); - updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", STATE_FIELD, TaskState.STOPPED.name()); - updateByQueryRequest.setScript(new Script(script)); - - client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { - List bulkFailures = r.getBulkFailures(); - if (isNullOrEmpty(bulkFailures)) { - logger.debug("Updated {} child entity tasks state for detector task {}", r.getUpdated(), detectorTaskId); - } else { - logger.error("Failed to update child entity task's state for detector task {} ", detectorTaskId); - } - }, e -> logger.error("Exception happened when update child entity task's state for detector task " + detectorTaskId, e))); + @Override + protected boolean isHistoricalHCTask(TimeSeriesTask task) { + return ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(task.getTaskType()); } /** @@ -1310,8 +771,9 @@ private void resetEntityTasksAsStopped(String detectorTaskId) { * @param listener action listener * @param response type of listener */ - public void cleanDetectorCache( - ADTask adTask, + @Override + public void cleanConfigCache( + TimeSeriesTask adTask, TransportService transportService, 
ExecutorFunction function, ActionListener listener @@ -1320,15 +782,12 @@ public void cleanDetectorCache( String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); try { - forwardADTaskToCoordinatingNode( - adTask, - ADTaskAction.CLEAN_CACHE, - transportService, - ActionListener.wrap(r -> { function.execute(); }, e -> { - logger.error("Failed to clear detector cache on coordinating node " + coordinatingNode, e); - listener.onFailure(e); - }) - ); + forwardADTaskToCoordinatingNode((ADTask) adTask, ADTaskAction.CLEAN_CACHE, transportService, ActionListener.wrap(r -> { + function.execute(); + }, e -> { + logger.error("Failed to clear detector cache on coordinating node " + coordinatingNode, e); + listener.onFailure(e); + })); } catch (ResourceNotFoundException e) { logger .warn( @@ -1347,161 +806,27 @@ public void cleanDetectorCache( protected void cleanDetectorCache(ADTask adTask, TransportService transportService, ExecutorFunction function) { String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); - cleanDetectorCache(adTask, transportService, function, ActionListener.wrap(r -> { + cleanConfigCache(adTask, transportService, function, ActionListener.wrap(r -> { logger.debug("Successfully cleaned cache for detector {}, task {}", detectorId, taskId); }, e -> { logger.error("Failed to clean cache for detector " + detectorId + ", task " + taskId, e); })); } - /** - * Get latest historical AD task profile. - * Will not reset task state in this method. - * - * @param detectorId detector id - * @param transportService transport service - * @param profile detector profile - * @param listener action listener - */ - public void getLatestHistoricalTaskProfile( - String detectorId, - TransportService transportService, - DetectorProfile profile, - ActionListener listener - ) { - getAndExecuteOnLatestADTask(detectorId, null, null, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { - if (adTask.isPresent()) { - getADTaskProfile(adTask.get(), ActionListener.wrap(adTaskProfile -> { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - profileBuilder.adTaskProfile(adTaskProfile); - DetectorProfile detectorProfile = profileBuilder.build(); - detectorProfile.merge(profile); - listener.onResponse(detectorProfile); - }, e -> { - logger.error("Failed to get AD task profile for task " + adTask.get().getTaskId(), e); - listener.onFailure(e); - })); - } else { - DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); - listener.onResponse(profileBuilder.build()); - } - }, transportService, false, listener); - } - - /** - * Get AD task profile. 
- * @param adDetectorLevelTask detector level task - * @param listener action listener - */ - private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener) { - String detectorId = adDetectorLevelTask.getConfigId(); - - hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { - ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); - client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { - if (response.hasFailures()) { - listener.onFailure(response.failures().get(0)); - return; - } - - List adEntityTaskProfiles = new ArrayList<>(); - ADTaskProfile detectorTaskProfile = new ADTaskProfile(adDetectorLevelTask); - for (ADTaskProfileNodeResponse node : response.getNodes()) { - ADTaskProfile taskProfile = node.getAdTaskProfile(); - if (taskProfile != null) { - if (taskProfile.getNodeId() != null) { - // HC detector: task profile from coordinating node - // Single entity detector: task profile from worker node - detectorTaskProfile.setTaskId(taskProfile.getTaskId()); - detectorTaskProfile.setShingleSize(taskProfile.getShingleSize()); - detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates()); - detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained()); - detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize()); - detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes()); - detectorTaskProfile.setNodeId(taskProfile.getNodeId()); - detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount()); - detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots()); - detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount()); - detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount()); - detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities()); - detectorTaskProfile.setAdTaskType(taskProfile.getAdTaskType()); - } - if (taskProfile.getEntityTaskProfiles() != null) { - adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles()); - } - } - } - if (adEntityTaskProfiles != null && adEntityTaskProfiles.size() > 0) { - detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles); - } - listener.onResponse(detectorTaskProfile); - }, e -> { - logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), e); - listener.onFailure(e); - })); - }, listener); - - } - - private String validateDetector(AnomalyDetector detector) { - String error = null; - if (detector.getFeatureAttributes().size() == 0) { - error = "Can't start detector job as no features configured"; - } else if (detector.getEnabledFeatureIds().size() == 0) { - error = "Can't start detector job as no enabled features configured"; - } - return error; - } - - private void updateLatestFlagOfOldTasksAndCreateNewTask( - AnomalyDetector detector, - DateRange detectionDateRange, - User user, - ActionListener listener - ) { - UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); - updateByQueryRequest.indices(DETECTION_STATE_INDEX); - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detector.getId())); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - // make sure we reset all latest task as false when user switch from single entity to HC, vice versa. 
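For context on the removed getADTaskProfile above: detector-level fields are copied from whichever node reports a nodeId (the coordinating node for HC detectors, the worker node for single-entity detectors), while entity task profiles are concatenated across all node responses. A trimmed-down, hypothetical mirror of that merge; the real profile carries many more fields:

import java.util.ArrayList;
import java.util.List;

final class ProfileMergeSketch {
    static final class NodeProfile {
        String nodeId;   // non-null only on the node that owns the detector-level task
        String taskId;
        List<String> entityTaskProfiles = new ArrayList<>();
    }

    static NodeProfile merge(List<NodeProfile> nodeResponses) {
        NodeProfile merged = new NodeProfile();
        for (NodeProfile profile : nodeResponses) {
            if (profile == null) {
                continue;
            }
            if (profile.nodeId != null) {
                // Detector-level fields come from the owning node.
                merged.nodeId = profile.nodeId;
                merged.taskId = profile.taskId;
            }
            // Entity-level profiles are collected from every node.
            merged.entityTaskProfiles.addAll(profile.entityTaskProfiles);
        }
        return merged;
    }
}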
- query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(getADTaskTypes(detectionDateRange, true)))); - updateByQueryRequest.setQuery(query); - updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", IS_LATEST_FIELD, false); - updateByQueryRequest.setScript(new Script(script)); - - client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { - List bulkFailures = r.getBulkFailures(); - if (bulkFailures.isEmpty()) { - // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job - // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct - // coordinating node once realtime job starts. - // For historical analysis, this method will be called on coordinating node, so we can set coordinating - // node as local node. - String coordinatingNode = detectionDateRange == null ? null : clusterService.localNode().getId(); - createNewADTask(detector, detectionDateRange, user, coordinatingNode, listener); - } else { - logger.error("Failed to update old task's state for detector: {}, response: {} ", detector.getId(), r.toString()); - listener.onFailure(bulkFailures.get(0).getCause()); - } - }, e -> { - logger.error("Failed to reset old tasks as not latest for detector " + detector.getId(), e); - listener.onFailure(e); - })); - } - - private void createNewADTask( - AnomalyDetector detector, + @Override + protected void createNewTask( + Config config, DateRange detectionDateRange, + boolean runOnce, User user, String coordinatingNode, - ActionListener listener + TaskState initialState, + ActionListener listener ) { String userName = user == null ? null : user.getName(); Instant now = Instant.now(); - String taskType = getADTaskType(detector, detectionDateRange).name(); + String taskType = getTaskType(config, detectionDateRange, runOnce).name(); ADTask adTask = new ADTask.Builder() - .configId(detector.getId()) - .detector(detector) + .configId(config.getId()) + .detector((AnomalyDetector) config) .isLatest(true) .taskType(taskType) .executionStartTime(now) @@ -1515,212 +840,74 @@ private void createNewADTask( .user(user) .build(); - createADTaskDirectly( - adTask, - r -> onIndexADTaskResponse( - r, - adTask, - (response, delegatedListener) -> cleanOldAdTaskDocs(response, adTask, delegatedListener), - listener - ), - listener - ); - } - - /** - * Create AD task directly without checking index exists of not. - * [Important!] 
Make sure listener returns in function - * - * @param adTask AD task - * @param function consumer function - * @param listener action listener - * @param action listener response type - */ - public void createADTaskDirectly(ADTask adTask, Consumer function, ActionListener listener) { - IndexRequest request = new IndexRequest(DETECTION_STATE_INDEX); - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - request - .source(adTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { - logger.error("Failed to create AD task for detector " + adTask.getConfigId(), e); - listener.onFailure(e); - })); - } catch (Exception e) { - logger.error("Failed to create AD task for detector " + adTask.getConfigId(), e); - listener.onFailure(e); - } - } - - private void onIndexADTaskResponse( - IndexResponse response, - ADTask adTask, - BiConsumer> function, - ActionListener listener - ) { - if (response == null || response.getResult() != CREATED) { - String errorMsg = getShardsFailure(response); - listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); - return; - } - adTask.setTaskId(response.getId()); - ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { - handleADTaskException(adTask, e); - if (e instanceof DuplicateTaskException) { - listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); - } else { - // For historical AD task, clear historical task if any other exception happened. - // For realtime AD, task cache will be inited when realtime job starts, check - // ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. Here the - // realtime task cache not inited yet when create AD task, so no need to cleanup. - if (adTask.isHistoricalTask()) { - adTaskCacheManager.removeHistoricalTaskCache(adTask.getConfigId()); - } - listener.onFailure(e); - } - }); - try { - // Put detector id in cache. If detector id already in cache, will throw - // DuplicateTaskException. This is to solve race condition when user send - // multiple start request for one historical detector. - if (adTask.isHistoricalTask()) { - adTaskCacheManager.add(adTask.getConfigId(), adTask); - } - } catch (Exception e) { - delegatedListener.onFailure(e); - return; - } - if (function != null) { - function.accept(response, delegatedListener); - } - } - - private void cleanOldAdTaskDocs(IndexResponse response, ADTask adTask, ActionListener delegatedListener) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, adTask.getConfigId())); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, false)); - - if (adTask.isHistoricalTask()) { - // If historical task, only delete detector level task. It may take longer time to delete entity tasks. - // We will delete child task (entity task) of detector level task in hourly cron job. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); - } else { - // We don't have entity level task for realtime detection, so will delete all tasks. 
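The duplicate-start guard mentioned above (the cache manager throws DuplicateTaskException for the second of two concurrent start requests) boils down to an atomic put-if-absent. A minimal sketch with hypothetical names:

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

final class StartGuardSketch {
    static class DuplicateTaskException extends RuntimeException {
        DuplicateTaskException(String message) { super(message); }
    }

    private final ConcurrentMap<String, String> runningHistoricalTasks = new ConcurrentHashMap<>();

    // The first start request wins; concurrent duplicates fail fast.
    void add(String configId, String taskId) {
        if (runningHistoricalTasks.putIfAbsent(configId, taskId) != null) {
            throw new DuplicateTaskException("A task is already running for config " + configId);
        }
    }

    void removeHistoricalTaskCache(String configId) {
        runningHistoricalTasks.remove(configId);
    }
}

Failing fast on the second request, rather than queueing it, is what lets the caller map the exception to a BAD_REQUEST response.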
- query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(REALTIME_TASK_TYPES))); - } - - SearchRequest searchRequest = new SearchRequest(); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder - .query(query) - .sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC) - // Search query "from" starts from 0. - .from(maxOldAdTaskDocsPerDetector) - .size(MAX_OLD_AD_TASK_DOCS); - searchRequest.source(sourceBuilder).indices(DETECTION_STATE_INDEX); - String detectorId = adTask.getConfigId(); - - deleteTaskDocs(detectorId, searchRequest, () -> { - if (adTask.isHistoricalTask()) { - // run batch result action for historical detection - runBatchResultAction(response, adTask, delegatedListener); - } else { - // return response directly for realtime detection - AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( - response.getId(), - response.getVersion(), - response.getSeqNo(), - response.getPrimaryTerm(), - RestStatus.OK - ); - delegatedListener.onResponse(anomalyDetectorJobResponse); - } - }, delegatedListener); - } - - protected void deleteTaskDocs( - String detectorId, - SearchRequest searchRequest, - ExecutorFunction function, - ActionListener listener - ) { - ActionListener searchListener = ActionListener.wrap(r -> { - Iterator iterator = r.getHits().iterator(); - if (iterator.hasNext()) { - BulkRequest bulkRequest = new BulkRequest(); - while (iterator.hasNext()) { - SearchHit searchHit = iterator.next(); - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - ADTask adTask = ADTask.parse(parser, searchHit.getId()); - logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getConfigId()); - bulkRequest.add(new DeleteRequest(DETECTION_STATE_INDEX).id(adTask.getTaskId())); - } catch (Exception e) { - listener.onFailure(e); - } - } - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { - logger.info("Old AD tasks deleted for detector {}", detectorId); - BulkItemResponse[] bulkItemResponses = res.getItems(); - if (bulkItemResponses != null && bulkItemResponses.length > 0) { - for (BulkItemResponse bulkItemResponse : bulkItemResponses) { - if (!bulkItemResponse.isFailed()) { - logger.debug("Add detector task into cache. Task id: {}", bulkItemResponse.getId()); - // add deleted task in cache and delete its child tasks and AD results - adTaskCacheManager.addDeletedTask(bulkItemResponse.getId()); - } - } - } - // delete child tasks and AD results of this task - cleanChildTasksAndADResultsOfDeletedTask(); - - function.execute(); - }, e -> { - logger.warn("Failed to clean AD tasks for detector " + detectorId, e); - listener.onFailure(e); - })); - } else { - function.execute(); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - function.execute(); - } else { - listener.onFailure(e); - } - }); - - client.search(searchRequest, searchListener); + createTaskDirectly( + adTask, + r -> onIndexConfigTaskResponse( + r, + adTask, + (response, delegatedListener) -> cleanOldConfigTaskDocs( + response, + adTask, + (indexResponse) -> (T) new JobResponse(indexResponse.getId()), + delegatedListener + ), + listener + ), + listener + ); } - /** - * Poll deleted detector task from cache and delete its child tasks and AD results. 
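The removed deleteTaskDocs pairs with the cleanOldAdTaskDocs query above: tasks are sorted by execution start time descending and the search begins at offset maxOldAdTaskDocsPerDetector, so only the overflow beyond the newest N docs is deleted. A standalone sketch of that windowing (Java 16+ for the record; names are illustrative):

import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

final class OldTaskPruneSketch {
    record TaskDoc(String taskId, long executionStartMillis) {}

    // Keep the newest maxDocsToKeep tasks; everything past that offset is delete-eligible,
    // mirroring sourceBuilder.sort(EXECUTION_START_TIME_FIELD, DESC).from(maxDocsToKeep).
    static List<TaskDoc> selectDeletable(List<TaskDoc> docs, int maxDocsToKeep) {
        return docs.stream()
            .sorted(Comparator.comparingLong(TaskDoc::executionStartMillis).reversed())
            .skip(maxDocsToKeep)
            .collect(Collectors.toList());
    }
}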
- */ - public void cleanChildTasksAndADResultsOfDeletedTask() { - if (!adTaskCacheManager.hasDeletedTask()) { + @Override + protected void onIndexConfigTaskResponse( + IndexResponse response, + ADTask adTask, + BiConsumer> function, + ActionListener listener + ) { + if (response == null || response.getResult() != CREATED) { + String errorMsg = ExceptionUtil.getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); return; } - threadPool.schedule(() -> { - String taskId = adTaskCacheManager.pollDeletedTask(); - if (taskId == null) { - return; + adTask.setTaskId(response.getId()); + ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + handleTaskException(adTask, e); + if (e instanceof DuplicateTaskException) { + listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); + } else { + // For a historical AD task, clear the historical task cache if any other exception happens. + // For realtime AD, the task cache will be initialized when the realtime job starts; check + // ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. Here the + // realtime task cache is not initialized yet when the AD task is created, so no cleanup is needed. + if (adTask.isHistoricalTask()) { + taskCacheManager.removeHistoricalTaskCache(adTask.getConfigId()); + } + listener.onFailure(e); } - DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); - deleteADResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId)); - client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(res -> { - logger.debug("Successfully deleted AD results of task " + taskId); - DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(DETECTION_STATE_INDEX); - deleteChildTasksRequest.setQuery(new TermsQueryBuilder(PARENT_TASK_ID_FIELD, taskId)); - - client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { - logger.debug("Successfully deleted child tasks of task " + taskId); - cleanChildTasksAndADResultsOfDeletedTask(); - }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); - }, ex -> { logger.error("Failed to delete AD results for task " + taskId, ex); })); - }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME); + }); + try { + // Put the config id in the cache. If the config id is already in the cache, this throws a + // DuplicateTaskException. This solves the race condition when a user sends + // multiple start requests for one historical run. + if (adTask.isHistoricalTask()) { + taskCacheManager.add(adTask.getConfigId(), adTask); + } + } catch (Exception e) { + delegatedListener.onFailure(e); + return; + } + if (function != null) { + function.accept(response, delegatedListener); + } } - private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionListener listener) { + @Override + protected void runBatchResultAction( + IndexResponse response, + ADTask adTask, + ResponseTransformer responseTransformer, + ActionListener listener + ) { client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { String remoteOrLocal = r.isRunTaskRemotely() ?
"remote" : "local"; logger @@ -1731,110 +918,9 @@ private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionL remoteOrLocal, r.getNodeId() ); - AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( - response.getId(), - response.getVersion(), - response.getSeqNo(), - response.getPrimaryTerm(), - RestStatus.OK - ); - listener.onResponse(anomalyDetectorJobResponse); - }, e -> listener.onFailure(e))); - } - - /** - * Handle exceptions for AD task. Update task state and record error message. - * - * @param adTask AD task - * @param e exception - */ - public void handleADTaskException(ADTask adTask, Exception e) { - // TODO: handle timeout exception - String state = TaskState.FAILED.name(); - Map updatedFields = new HashMap<>(); - if (e instanceof DuplicateTaskException) { - // If user send multiple start detector request, we will meet race condition. - // Cache manager will put first request in cache and throw DuplicateTaskException - // for the second request. We will delete the second task. - logger - .warn( - "There is already one running task for detector, detectorId:" - + adTask.getConfigId() - + ". Will delete task " - + adTask.getTaskId() - ); - deleteADTask(adTask.getTaskId()); - return; - } - if (e instanceof TaskCancelledException) { - logger.info("AD task cancelled, taskId: {}, detectorId: {}", adTask.getTaskId(), adTask.getConfigId()); - state = TaskState.STOPPED.name(); - String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); - if (stoppedBy != null) { - updatedFields.put(STOPPED_BY_FIELD, stoppedBy); - } - } else { - logger.error("Failed to execute AD batch task, task id: " + adTask.getTaskId() + ", detector id: " + adTask.getConfigId(), e); - } - updatedFields.put(ERROR_FIELD, getErrorMessage(e)); - updatedFields.put(STATE_FIELD, state); - updatedFields.put(EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); - updateADTask(adTask.getTaskId(), updatedFields); - } - /** - * Update AD task with specific fields. - * - * @param taskId AD task id - * @param updatedFields updated fields, key: filed name, value: new value - */ - public void updateADTask(String taskId, Map updatedFields) { - updateADTask(taskId, updatedFields, ActionListener.wrap(response -> { - if (response.status() == RestStatus.OK) { - logger.debug("Updated AD task successfully: {}, task id: {}", response.status(), taskId); - } else { - logger.error("Failed to update AD task {}, status: {}", taskId, response.status()); - } - }, e -> { logger.error("Failed to update task: " + taskId, e); })); - } - - /** - * Update AD task for specific fields. - * - * @param taskId task id - * @param updatedFields updated fields, key: filed name, value: new value - * @param listener action listener - */ - public void updateADTask(String taskId, Map updatedFields, ActionListener listener) { - UpdateRequest updateRequest = new UpdateRequest(DETECTION_STATE_INDEX, taskId); - Map updatedContent = new HashMap<>(); - updatedContent.putAll(updatedFields); - updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); - updateRequest.doc(updatedContent); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.update(updateRequest, listener); - } - - /** - * Delete AD task with task id. 
- * - * @param taskId AD task id - */ - public void deleteADTask(String taskId) { - deleteADTask(taskId, ActionListener.wrap(r -> { logger.info("Deleted AD task {} with status: {}", taskId, r.status()); }, e -> { - logger.error("Failed to delete AD task " + taskId, e); - })); - } - - /** - * Delete AD task with task id. - * - * @param taskId AD task id - * @param listener action listener - */ - public void deleteADTask(String taskId, ActionListener listener) { - DeleteRequest deleteRequest = new DeleteRequest(DETECTION_STATE_INDEX, taskId); - client.delete(deleteRequest, listener); + listener.onResponse(responseTransformer.transform(response)); + }, e -> listener.onFailure(e))); } /** @@ -1847,7 +933,7 @@ public void deleteADTask(String taskId, ActionListener listener) * @return AD task cancellation state */ public ADTaskCancellationState cancelLocalTaskByDetectorId(String detectorId, String detectorTaskId, String reason, String userName) { - ADTaskCancellationState cancellationState = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + ADTaskCancellationState cancellationState = taskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); logger .debug( "Cancelled AD task for detector: {}, state: {}, cancelled by: {}, reason: {}", @@ -1859,199 +945,6 @@ public ADTaskCancellationState cancelLocalTaskByDetectorId(String detectorId, St return cancellationState; } - /** - * Delete AD tasks docs. - * [Important!] Make sure listener returns in function - * - * @param detectorId detector id - * @param function AD function - * @param listener action listener - */ - public void deleteADTasks(String detectorId, ExecutorFunction function, ActionListener listener) { - DeleteByQueryRequest request = new DeleteByQueryRequest(DETECTION_STATE_INDEX); - - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); - - request.setQuery(query); - client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(r -> { - if (r.getBulkFailures() == null || r.getBulkFailures().size() == 0) { - logger.info("AD tasks deleted for detector {}", detectorId); - deleteADResultOfDetector(detectorId); - function.execute(); - } else { - listener.onFailure(new OpenSearchStatusException("Failed to delete all AD tasks", RestStatus.INTERNAL_SERVER_ERROR)); - } - }, e -> { - logger.info("Failed to delete AD tasks for " + detectorId, e); - if (e instanceof IndexNotFoundException) { - deleteADResultOfDetector(detectorId); - function.execute(); - } else { - listener.onFailure(e); - } - })); - } - - private void deleteADResultOfDetector(String detectorId) { - if (!deleteADResultWhenDeleteDetector) { - logger.info("Won't delete ad result for {} as delete AD result setting is disabled", detectorId); - return; - } - logger.info("Start to delete AD results of detector {}", detectorId); - DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); - deleteADResultsRequest.setQuery(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); - client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(response -> { - logger.debug("Successfully deleted AD results of detector " + detectorId); - }, exception -> { - logger.error("Failed to delete AD results of detector " + detectorId, exception); - adTaskCacheManager.addDeletedConfig(detectorId); - })); - } - - /** - * Clean AD results of deleted detector. 
- */ - public void cleanADResultOfDeletedDetector() { - String detectorId = adTaskCacheManager.pollDeletedConfig(); - if (detectorId != null) { - deleteADResultOfDetector(detectorId); - } - } - - /** - * Update latest AD task of detector. - * - * @param detectorId detector id - * @param taskTypes task types - * @param updatedFields updated fields, key: filed name, value: new value - * @param listener action listener - */ - public void updateLatestADTask( - String detectorId, - List taskTypes, - Map updatedFields, - ActionListener listener - ) { - getAndExecuteOnLatestDetectorLevelTask(detectorId, taskTypes, (adTask) -> { - if (adTask.isPresent()) { - updateADTask(adTask.get().getTaskId(), updatedFields, listener); - } else { - listener.onFailure(new ResourceNotFoundException(detectorId, CAN_NOT_FIND_LATEST_TASK)); - } - }, null, false, listener); - } - - /** - * Update latest realtime task. - * - * @param detectorId detector id - * @param state task state - * @param error error - * @param transportService transport service - * @param listener action listener - */ - public void stopLatestRealtimeTask( - String detectorId, - TaskState state, - Exception error, - TransportService transportService, - ActionListener listener - ) { - getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTask) -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - Map updatedFields = new HashMap<>(); - updatedFields.put(ADTask.STATE_FIELD, state.name()); - if (error != null) { - updatedFields.put(ADTask.ERROR_FIELD, error.getMessage()); - } - ExecutorFunction function = () -> updateADTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { - if (error == null) { - listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); - } else { - listener.onFailure(error); - } - }, e -> { listener.onFailure(e); })); - - String coordinatingNode = adTask.get().getCoordinatingNode(); - if (coordinatingNode != null && transportService != null) { - cleanDetectorCache(adTask.get(), transportService, function, listener); - } else { - function.execute(); - } - } else { - listener.onFailure(new OpenSearchStatusException("Anomaly detector job is already stopped: " + detectorId, RestStatus.OK)); - } - }, null, false, listener); - } - - /** - * Update realtime task cache on realtime detector's coordinating node. 
- * - * @param detectorId detector id - * @param state new state - * @param rcfTotalUpdates rcf total updates - * @param detectorIntervalInMinutes detector interval in minutes - * @param error error - * @param listener action listener - */ - public void updateLatestRealtimeTaskOnCoordinatingNode( - String detectorId, - String state, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error, - ActionListener listener - ) { - Float initProgress = null; - String newState = null; - // calculate init progress and task state with RCF total updates - if (detectorIntervalInMinutes != null && rcfTotalUpdates != null) { - newState = TaskState.INIT.name(); - if (rcfTotalUpdates < NUM_MIN_SAMPLES) { - initProgress = (float) rcfTotalUpdates / NUM_MIN_SAMPLES; - } else { - newState = TaskState.RUNNING.name(); - initProgress = 1.0f; - } - } - // Check if new state is not null and override state calculated with rcf total updates - if (state != null) { - newState = state; - } - - error = Optional.ofNullable(error).orElse(""); - if (!adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId, newState, initProgress, error)) { - // If task not changed, no need to update, just return - listener.onResponse(null); - return; - } - Map updatedFields = new HashMap<>(); - updatedFields.put(COORDINATING_NODE_FIELD, clusterService.localNode().getId()); - if (initProgress != null) { - updatedFields.put(INIT_PROGRESS_FIELD, initProgress); - updatedFields.put(ESTIMATED_MINUTES_LEFT_FIELD, Math.max(0, NUM_MIN_SAMPLES - rcfTotalUpdates) * detectorIntervalInMinutes); - } - if (newState != null) { - updatedFields.put(STATE_FIELD, newState); - } - if (error != null) { - updatedFields.put(ERROR_FIELD, error); - } - Float finalInitProgress = initProgress; - // Variable used in lambda expression should be final or effectively final - String finalError = error; - String finalNewState = newState; - updateLatestADTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, updatedFields, ActionListener.wrap(r -> { - logger.debug("Updated latest realtime AD task successfully for detector {}", detectorId); - adTaskCacheManager.updateRealtimeTaskCache(detectorId, finalNewState, finalInitProgress, finalError); - listener.onResponse(r); - }, e -> { - logger.error("Failed to update realtime task for detector " + detectorId, e); - listener.onFailure(e); - })); - } - /** * Init realtime task cache and clean up realtime task cache on old coordinating node. Realtime AD * depends on job scheduler to choose node (job coordinating node) to run AD job. Nodes have primary @@ -2065,33 +958,37 @@ public void updateLatestRealtimeTaskOnCoordinatingNode( * listener will return false. 
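For reference, the init-progress arithmetic in the removed updateLatestRealtimeTaskOnCoordinatingNode above: progress is rcfTotalUpdates / NUM_MIN_SAMPLES until enough samples arrive, then the state flips from INIT to RUNNING. A standalone sketch (Java 16+ for the record; NUM_MIN_SAMPLES is given an illustrative value here, the plugin defines its own constant):

final class InitProgressSketch {
    static final long NUM_MIN_SAMPLES = 32L; // illustrative value only

    record RealtimeTaskState(String state, float initProgress, long estimatedMinutesLeft) {}

    static RealtimeTaskState fromRcfUpdates(long rcfTotalUpdates, long intervalMinutes) {
        if (rcfTotalUpdates < NUM_MIN_SAMPLES) {
            float progress = (float) rcfTotalUpdates / NUM_MIN_SAMPLES;
            long minutesLeft = Math.max(0, NUM_MIN_SAMPLES - rcfTotalUpdates) * intervalMinutes;
            return new RealtimeTaskState("INIT", progress, minutesLeft);
        }
        return new RealtimeTaskState("RUNNING", 1.0f, 0L);
    }

    public static void main(String[] args) {
        // 8 of 32 samples at a 10-minute interval: 25% init progress, about 240 minutes left.
        System.out.println(fromRcfUpdates(8, 10));
    }
}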
* * @param detectorId detector id - * @param detector anomaly detector + * @param config config accessor * @param transportService transport service * @param listener listener */ + @Override public void initRealtimeTaskCacheAndCleanupStaleCache( String detectorId, - AnomalyDetector detector, + Config config, TransportService transportService, ActionListener listener ) { try { - if (adTaskCacheManager.getRealtimeTaskCache(detectorId) != null) { + if (taskCacheManager.getRealtimeTaskCache(detectorId) != null) { listener.onResponse(false); return; } - getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { + AnomalyDetector detector = (AnomalyDetector) config; + getAndExecuteOnLatestConfigLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { if (!adTaskOptional.isPresent()) { logger.debug("Can't find realtime task for detector {}, init realtime task cache directly", detectorId); - ExecutorFunction function = () -> createNewADTask( + ExecutorFunction function = () -> createNewTask( detector, null, + false, detector.getUser(), clusterService.localNode().getId(), + TaskState.CREATED, ActionListener.wrap(r -> { logger.info("Recreate realtime task successfully for detector {}", detectorId); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); }, e -> { logger.error("Failed to recreate realtime task for detector " + detectorId, e); @@ -2113,19 +1010,19 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( localNodeId, detectorId ); - cleanDetectorCache(adTask, transportService, () -> { + cleanConfigCache(adTask, transportService, () -> { logger .info( "Realtime task cache cleaned on old coordinating node {} for detector {}", oldCoordinatingNode, detectorId ); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); }, listener); } else { logger.info("Init realtime task cache for detector {}", detectorId); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); } }, transportService, false, listener); @@ -2136,16 +1033,16 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( } private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener) { - if (detectionIndices.doesStateIndexExist()) { + if (indexManagement.doesStateIndexExist()) { function.execute(); } else { // If detection index doesn't exist, create index and execute function. 
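recreateRealtimeTask here follows the same ensure-index-then-run shape as the startDetector path at the top of this section: create the state index lazily, treat an unacknowledged create as a failure, and tolerate losing the creation race. A compact sketch of that pattern with hypothetical stand-ins (IndexAdmin, ResourceAlreadyExists):

import java.util.function.Consumer;

final class EnsureIndexSketch {
    static class ResourceAlreadyExists extends RuntimeException {}

    interface IndexAdmin {
        boolean stateIndexExists();
        void createStateIndex(Consumer<Boolean> onAcknowledged, Consumer<Exception> onFailure);
    }

    // Run `action` once the state index exists; a concurrent creator winning the race is fine.
    static void ensureStateIndexThen(IndexAdmin admin, Runnable action, Consumer<Exception> onFailure) {
        if (admin.stateIndexExists()) {
            action.run();
            return;
        }
        admin.createStateIndex(acknowledged -> {
            if (acknowledged) {
                action.run();
            } else {
                onFailure.accept(new IllegalStateException("create index not acknowledged"));
            }
        }, e -> {
            if (e instanceof ResourceAlreadyExists) {
                action.run(); // mirrors the ResourceAlreadyExistsException branch in the diff
            } else {
                onFailure.accept(e);
            }
        });
    }
}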
- detectionIndices.initStateIndex(ActionListener.wrap(r -> { + indexManagement.initStateIndex(ActionListener.wrap(r -> { if (r.isAcknowledged()) { logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); function.execute(); } else { - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); logger.warn(error); listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); } @@ -2160,14 +1057,6 @@ private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener + ActionListener listener ) { try { ADTaskAction action = getAdEntityTaskAction(adTask, exception); @@ -2225,7 +1114,7 @@ private void entityTaskDone( private ADTaskAction getAdEntityTaskAction(ADTask adTask, Exception exception) { ADTaskAction action = ADTaskAction.NEXT_ENTITY; if (exception != null) { - adTask.setError(getErrorMessage(exception)); + adTask.setError(ExceptionUtil.getErrorMessage(exception)); if (exception instanceof LimitExceededException && isRetryableError(exception.getMessage())) { action = ADTaskAction.PUSH_BACK_ENTITY; } else if (exception instanceof TaskCancelledException || exception instanceof EndRunException) { @@ -2261,14 +1150,14 @@ public boolean isRetryableError(String error) { * @param state AD task state * @param listener action listener */ - public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener listener) { + public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener listener) { String detectorId = adTask.getConfigId(); - String taskId = adTask.isEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); + String taskId = adTask.isHistoricalEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); String detectorTaskId = adTask.getConfigLevelTaskId(); ActionListener wrappedListener = ActionListener.wrap(response -> { logger.info("Historical HC detector done with state: {}. Remove from cache, detector id:{}", state.name(), detectorId); - adTaskCacheManager.removeHistoricalTaskCache(detectorId); + taskCacheManager.removeHistoricalTaskCache(detectorId); }, e -> { // HC detector task may fail to update as FINISHED for some edge case if failed to get updating semaphore. // Will reset task state when get detector with task or maintain tasks in hourly cron. 
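An aside on the semaphore noted in the comment above: updateADHCDetectorTask (below) serializes updates to the HC detector task through a per-detector semaphore with a bounded wait, and releases it via ActionListener.runAfter. The same discipline in plain java.util.concurrent, as a minimal sketch:

import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

final class GuardedUpdateSketch {
    private final Semaphore updating = new Semaphore(1);

    // Try for at most timeoutMillis; on success, always release after the update completes.
    boolean tryUpdate(long timeoutMillis, Runnable update) throws InterruptedException {
        if (!updating.tryAcquire(timeoutMillis, TimeUnit.MILLISECONDS)) {
            return false; // the caller decides whether to drop or retry, as updateADHCDetectorTask does
        }
        try {
            update.run();
        } finally {
            updating.release(); // mirrors ActionListener.runAfter(listener, release)
        }
        return true;
    }
}

Returning false on timeout instead of blocking indefinitely matches the diff's behavior of logging and responding with null when the cache no longer owns the task.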
@@ -2277,7 +1166,7 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener } else { logger.error("Failed to update task: " + taskId, e); } - adTaskCacheManager.removeHistoricalTaskCache(detectorId); + taskCacheManager.removeHistoricalTaskCache(detectorId); }); long timeoutInMillis = 2000;// wait for 2 seconds to acquire updating HC detector task semaphore @@ -2293,11 +1182,11 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, hcDetectorTaskState.name(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2307,20 +1196,20 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener }, e -> { logger.error("Failed to get finished entity tasks", e); - String errorMessage = getErrorMessage(e); + String errorMessage = ExceptionUtil.getErrorMessage(e); threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { updateADHCDetectorTask( detectorId, taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0, - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, errorMessage, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2335,11 +1224,11 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, state.name(), - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, adTask.getError(), - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2349,7 +1238,7 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener } - listener.onResponse(new AnomalyDetectorJobResponse(taskId, 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(taskId)); } /** @@ -2361,9 +1250,12 @@ public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener */ public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); - queryBuilder.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); + queryBuilder.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, detectorTaskId)); if (taskStates != null && taskStates.size() > 0) { - queryBuilder.filter(new TermsQueryBuilder(STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList()))); + queryBuilder + .filter( + new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList())) + ); } SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(queryBuilder); @@ -2426,19 +1318,19 @@ private void updateADHCDetectorTask( ActionListener listener ) { try { - if (adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { + if (taskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { try { - updateADTask( + updateTask( taskId, updatedFields, - ActionListener.runAfter(listener, () -> { adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) + ActionListener.runAfter(listener, () -> { 
taskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) ); } catch (Exception e) { logger.error("Failed to update detector task " + taskId, e); - adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); + taskCacheManager.releaseTaskUpdatingSemaphore(detectorId); listener.onFailure(e); } - } else if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + } else if (!taskCacheManager.isHCTaskCoordinatingNode(detectorId)) { // It's possible that AD task cache cleaned up by other task. Return null to avoid too many failure logs. logger.info("HC detector task cache does not exist, detectorId:{}, taskId:{}", detectorId, taskId); listener.onResponse(null); @@ -2463,11 +1355,7 @@ private void updateADHCDetectorTask( * @param transportService transport service * @param listener action listener */ - public void runNextEntityForHCADHistorical( - ADTask adTask, - TransportService transportService, - ActionListener listener - ) { + public void runNextEntityForHCADHistorical(ADTask adTask, TransportService transportService, ActionListener listener) { String detectorId = adTask.getConfigId(); int scaleDelta = scaleTaskSlots(adTask, transportService, ActionListener.wrap(r -> { logger.debug("Scale up task slots done for detector {}, task {}", detectorId, adTask.getTaskId()); @@ -2478,9 +1366,9 @@ public void runNextEntityForHCADHistorical( "Have scaled down task slots. Will not poll next entity for detector {}, task {}, task slots: {}", detectorId, adTask.getTaskId(), - adTaskCacheManager.getDetectorTaskSlots(detectorId) + taskCacheManager.getDetectorTaskSlots(detectorId) ); - listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.ACCEPTED)); + listener.onResponse(new JobResponse(detectorId)); return; } client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { @@ -2493,7 +1381,7 @@ public void runNextEntityForHCADHistorical( remoteOrLocal, r.getNodeId() ); - AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK); + JobResponse anomalyDetectorJobResponse = new JobResponse(detectorId); listener.onResponse(anomalyDetectorJobResponse); }, e -> { listener.onFailure(e); })); } @@ -2508,11 +1396,7 @@ public void runNextEntityForHCADHistorical( * @param scaleUpListener action listener * @return task slots scale delta */ - protected int scaleTaskSlots( - ADTask adTask, - TransportService transportService, - ActionListener scaleUpListener - ) { + protected int scaleTaskSlots(ADTask adTask, TransportService transportService, ActionListener scaleUpListener) { String detectorId = adTask.getConfigId(); if (!scaleEntityTaskLane.tryAcquire()) { logger.debug("Can't get scaleEntityTaskLane semaphore"); @@ -2521,9 +1405,9 @@ protected int scaleTaskSlots( try { int scaleDelta = detectorTaskSlotScaleDelta(detectorId); logger.debug("start to scale task slots for detector {} with delta {}", detectorId, scaleDelta); - if (adTaskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { + if (taskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { // scale up to run more entities in parallel - Instant lastScaleEntityTaskLaneTime = adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); + Instant lastScaleEntityTaskLaneTime = taskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); if (lastScaleEntityTaskLaneTime == null) { logger.debug("lastScaleEntityTaskLaneTime is null for detector {}", 
detectorId); scaleEntityTaskLane.release(); @@ -2533,7 +1417,7 @@ protected int scaleTaskSlots( .plusMillis(SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS) .isBefore(Instant.now()); if (lastScaleTimeExpired) { - adTaskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); + taskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); logger.debug("Forward scale entity task lane request to lead node for detector {}", detectorId); forwardScaleTaskSlotRequestToLeadNode( adTask, @@ -2551,9 +1435,9 @@ protected int scaleTaskSlots( } } else { if (scaleDelta < 0) { // scale down to release task slots for other detectors - int runningEntityCount = adTaskCacheManager.getRunningEntityCount(detectorId) + adTaskCacheManager + int runningEntityCount = taskCacheManager.getRunningEntityCount(detectorId) + taskCacheManager .getTempEntityCount(detectorId); - int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int assignedTaskSlots = taskCacheManager.getDetectorTaskSlots(detectorId); int scaleDownDelta = Math.min(assignedTaskSlots - runningEntityCount, 0 - scaleDelta); logger .debug( @@ -2563,7 +1447,7 @@ protected int scaleTaskSlots( runningEntityCount, scaleDownDelta ); - adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); + taskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); } scaleEntityTaskLane.release(); } @@ -2592,13 +1476,13 @@ protected int scaleTaskSlots( * @return detector task slots scale delta */ public int detectorTaskSlotScaleDelta(String detectorId) { - DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalAdVersion(); - int unfinishedEntities = adTaskCacheManager.getUnfinishedEntityCount(detectorId); + DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalVersion(); + int unfinishedEntities = taskCacheManager.getUnfinishedEntityCount(detectorId); int totalTaskSlots = eligibleDataNodes.length * maxAdBatchTaskPerNode; int taskLaneLimit = Math.min(unfinishedEntities, Math.min(totalTaskSlots, maxRunningEntitiesPerDetector)); - adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); + taskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); - int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int assignedTaskSlots = taskCacheManager.getDetectorTaskSlots(detectorId); int scaleDelta = taskLaneLimit - assignedTaskSlots; logger .debug( @@ -2622,8 +1506,8 @@ public int detectorTaskSlotScaleDelta(String detectorId) { * @return task progress */ public float hcDetectorProgress(String detectorId) { - int entityCount = adTaskCacheManager.getTopEntityCount(detectorId); - int leftEntities = adTaskCacheManager.getPendingEntityCount(detectorId) + adTaskCacheManager.getRunningEntityCount(detectorId); + int entityCount = taskCacheManager.getTopEntityCount(detectorId); + int leftEntities = taskCacheManager.getPendingEntityCount(detectorId) + taskCacheManager.getRunningEntityCount(detectorId); return 1 - (float) leftEntities / entityCount; } @@ -2633,39 +1517,39 @@ public float hcDetectorProgress(String detectorId) { * @return list of AD task profile */ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { - List tasksOfDetector = adTaskCacheManager.getTasksOfDetector(detectorId); + List tasksOfDetector = taskCacheManager.getTasksOfDetector(detectorId); ADTaskProfile detectorTaskProfile = null; String localNodeId = clusterService.localNode().getId(); - if (adTaskCacheManager.isHCTaskRunning(detectorId)) { + if 
(taskCacheManager.isHCTaskRunning(detectorId)) { detectorTaskProfile = new ADTaskProfile(); - if (adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + if (taskCacheManager.isHCTaskCoordinatingNode(detectorId)) { detectorTaskProfile.setNodeId(localNodeId); - detectorTaskProfile.setTaskId(adTaskCacheManager.getDetectorTaskId(detectorId)); - detectorTaskProfile.setDetectorTaskSlots(adTaskCacheManager.getDetectorTaskSlots(detectorId)); - detectorTaskProfile.setTotalEntitiesInited(adTaskCacheManager.topEntityInited(detectorId)); - detectorTaskProfile.setTotalEntitiesCount(adTaskCacheManager.getTopEntityCount(detectorId)); - detectorTaskProfile.setPendingEntitiesCount(adTaskCacheManager.getPendingEntityCount(detectorId)); - detectorTaskProfile.setRunningEntitiesCount(adTaskCacheManager.getRunningEntityCount(detectorId)); - detectorTaskProfile.setRunningEntities(adTaskCacheManager.getRunningEntities(detectorId)); - detectorTaskProfile.setAdTaskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()); - Instant latestHCTaskRunTime = adTaskCacheManager.getLatestHCTaskRunTime(detectorId); + detectorTaskProfile.setTaskId(taskCacheManager.getDetectorTaskId(detectorId)); + detectorTaskProfile.setDetectorTaskSlots(taskCacheManager.getDetectorTaskSlots(detectorId)); + detectorTaskProfile.setTotalEntitiesInited(taskCacheManager.topEntityInited(detectorId)); + detectorTaskProfile.setTotalEntitiesCount(taskCacheManager.getTopEntityCount(detectorId)); + detectorTaskProfile.setPendingEntitiesCount(taskCacheManager.getPendingEntityCount(detectorId)); + detectorTaskProfile.setRunningEntitiesCount(taskCacheManager.getRunningEntityCount(detectorId)); + detectorTaskProfile.setRunningEntities(taskCacheManager.getRunningEntities(detectorId)); + detectorTaskProfile.setTaskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()); + Instant latestHCTaskRunTime = taskCacheManager.getLatestHCTaskRunTime(detectorId); if (latestHCTaskRunTime != null) { detectorTaskProfile.setLatestHCTaskRunTime(latestHCTaskRunTime.toEpochMilli()); } } if (tasksOfDetector.size() > 0) { - List entityTaskProfiles = new ArrayList<>(); + List entityTaskProfiles = new ArrayList<>(); tasksOfDetector.forEach(taskId -> { - ADEntityTaskProfile entityTaskProfile = new ADEntityTaskProfile( - adTaskCacheManager.getShingle(taskId).size(), - adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), - adTaskCacheManager.isThresholdModelTrained(taskId), - adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), - adTaskCacheManager.getModelSize(taskId), + EntityTaskProfile entityTaskProfile = new EntityTaskProfile( + taskCacheManager.getShingle(taskId).size(), + taskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + taskCacheManager.isThresholdModelTrained(taskId), + taskCacheManager.getThresholdModelTrainingDataSize(taskId), + taskCacheManager.getModelSize(taskId), localNodeId, - adTaskCacheManager.getEntity(taskId), + taskCacheManager.getEntity(taskId), taskId, ADTaskType.HISTORICAL_HC_ENTITY.name() ); @@ -2684,12 +1568,12 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { if (tasksOfDetector.size() == 1) { String taskId = tasksOfDetector.get(0); detectorTaskProfile = new ADTaskProfile( - adTaskCacheManager.getDetectorTaskId(detectorId), - adTaskCacheManager.getShingle(taskId).size(), - adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), - adTaskCacheManager.isThresholdModelTrained(taskId), - adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), - 
adTaskCacheManager.getModelSize(taskId), + taskCacheManager.getDetectorTaskId(detectorId), + taskCacheManager.getShingle(taskId).size(), + taskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + taskCacheManager.isThresholdModelTrained(taskId), + taskCacheManager.getThresholdModelTrainingDataSize(taskId), + taskCacheManager.getModelSize(taskId), localNodeId ); // Single-flow detector only has 1 task slot. @@ -2702,7 +1586,7 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { // Clean expired HC batch task run states as it may exists after HC historical analysis done if user cancel // before querying top entities done. We will clean it in hourly cron, check "maintainRunningHistoricalTasks" // method. Clean it up here when get task profile to release memory earlier. - adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + taskCacheManager.cleanExpiredHCBatchTaskRunStates(); }); logger.debug("Local AD task profile of detector {}: {}", detectorId, detectorTaskProfile); return detectorTaskProfile; @@ -2752,37 +1636,22 @@ public synchronized void removeStaleRunningEntity( ADTask adTask, String entity, TransportService transportService, - ActionListener listener + ActionListener listener ) { String detectorId = adTask.getConfigId(); - boolean removed = adTaskCacheManager.removeRunningEntity(detectorId, entity); - if (removed && adTaskCacheManager.getPendingEntityCount(detectorId) > 0) { + boolean removed = taskCacheManager.removeRunningEntity(detectorId, entity); + if (removed && taskCacheManager.getPendingEntityCount(detectorId) > 0) { logger.debug("kick off next pending entities"); this.runNextEntityForHCADHistorical(adTask, transportService, listener); } else { - if (!adTaskCacheManager.hasEntity(detectorId)) { + if (!taskCacheManager.hasEntity(detectorId)) { setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } } } - public boolean skipUpdateHCRealtimeTask(String detectorId, String error) { - RealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - return realtimeTaskCache != null - && realtimeTaskCache.getInitProgress() != null - && realtimeTaskCache.getInitProgress().floatValue() == 1.0 - && Objects.equals(error, realtimeTaskCache.getError()); - } - - public boolean isHCRealtimeTaskStartInitializing(String detectorId) { - RealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - return realtimeTaskCache != null - && realtimeTaskCache.getInitProgress() != null - && realtimeTaskCache.getInitProgress().floatValue() > 0; - } - public String convertEntityToString(ADTask adTask) { - if (adTask == null || !adTask.isEntityTask()) { + if (adTask == null || !adTask.isHistoricalEntityTask()) { return null; } AnomalyDetector detector = adTask.getDetector(); @@ -2869,45 +1738,8 @@ public void getADTask(String taskId, ActionListener> listener) })); } - /** - * Set old AD task's latest flag as false. 
- * @param adTasks list of AD tasks - */ - public void resetLatestFlagAsFalse(List adTasks) { - if (adTasks == null || adTasks.size() == 0) { - return; - } - BulkRequest bulkRequest = new BulkRequest(); - adTasks.forEach(task -> { - try { - task.setLatest(false); - task.setLastUpdateTime(Instant.now()); - IndexRequest indexRequest = new IndexRequest(DETECTION_STATE_INDEX) - .id(task.getTaskId()) - .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); - bulkRequest.add(indexRequest); - } catch (Exception e) { - logger.error("Fail to parse task AD task to XContent, task id " + task.getTaskId(), e); - } - }); - - bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { - BulkItemResponse[] bulkItemResponses = res.getItems(); - if (bulkItemResponses != null && bulkItemResponses.length > 0) { - for (BulkItemResponse bulkItemResponse : bulkItemResponses) { - if (!bulkItemResponse.isFailed()) { - logger.warn("Reset AD tasks latest flag as false Successfully. Task id: {}", bulkItemResponse.getId()); - } else { - logger.warn("Failed to reset AD tasks latest flag as false. Task id: " + bulkItemResponse.getId()); - } - } - } - }, e -> { logger.warn("Failed to reset AD tasks latest flag as false", e); })); - } - public int getLocalAdUsedBatchTaskSlot() { - return adTaskCacheManager.getTotalBatchTaskCount(); + return taskCacheManager.getTotalBatchTaskCount(); } /** @@ -2933,7 +1765,7 @@ public int getLocalAdUsedBatchTaskSlot() { * @return assigned batch task slots */ public int getLocalAdAssignedBatchTaskSlot() { - return adTaskCacheManager.getTotalDetectorTaskSlots(); + return taskCacheManager.getTotalDetectorTaskSlots(); } // ========================================================= @@ -2953,23 +1785,23 @@ public int getLocalAdAssignedBatchTaskSlot() { */ public void maintainRunningHistoricalTasks(TransportService transportService, int size) { // Clean expired HC batch task run state cache. - adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + taskCacheManager.cleanExpiredHCBatchTaskRunStates(); // Find owning node with highest AD version to make sure we only have 1 node maintain running historical tasks // and we use the latest logic. - Optional owningNode = hashRing.getOwningNodeWithHighestAdVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); + Optional owningNode = hashRing.getOwningNodeWithHighestVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); if (!owningNode.isPresent() || !clusterService.localNode().getId().equals(owningNode.get().getId())) { return; } logger.info("Start to maintain running historical tasks"); BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); - query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); + query.filter(new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, NOT_ENDED_STATES)); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // default maintain interval is 5 seconds, so maintain 10 tasks will take at least 50 seconds. 
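The hunk below builds the maintenance query: only the latest, not-yet-ended historical tasks, sorted newest first and capped at `size` so one maintenance pass stays bounded. As a rough standalone sketch of the same query shape (a sketch only: the literal field names, index alias, task types, and states below are assumptions inferred from this diff, not verbatim project constants):

```java
import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

public final class MaintenanceQuerySketch {
    /** Latest, still-running historical tasks, newest first, at most {@code size} of them. */
    public static SearchRequest latestRunningHistoricalTasks(int size) {
        BoolQueryBuilder query = new BoolQueryBuilder()
            // only the latest task document per detector
            .filter(new TermQueryBuilder("is_latest", true))
            // historical detector-level task types (representative values)
            .filter(new TermsQueryBuilder("task_type", "HISTORICAL_SINGLE_ENTITY", "HISTORICAL_HC_DETECTOR"))
            // "not ended" states (representative values)
            .filter(new TermsQueryBuilder("state", "CREATED", "INIT", "RUNNING"));
        SearchSourceBuilder source = new SearchSourceBuilder()
            .query(query)
            .sort("last_update_time", SortOrder.DESC) // most recently active first
            .size(size);                              // ~5s maintain interval per task, so keep this small
        return new SearchRequest(".opendistro-anomaly-detection-state").source(source);
    }
}
```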
- sourceBuilder.query(query).sort(LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); + sourceBuilder.query(query).sort(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); SearchRequest searchRequest = new SearchRequest(); searchRequest.source(sourceBuilder); searchRequest.indices(DETECTION_STATE_INDEX); @@ -3007,7 +1839,7 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue return; } threadPool.schedule(() -> { - resetHistoricalDetectorTaskState(ImmutableList.of(adTask), () -> { + resetHistoricalConfigTaskState(ImmutableList.of(adTask), () -> { logger.debug("Finished maintaining running historical task {}", adTask.getTaskId()); maintainRunningHistoricalTask(taskQueue, transportService); }, transportService, ActionListener.wrap(r -> { @@ -3017,20 +1849,88 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue } /** - * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime - * task cache directly if expired. + * Get list of task types. + * 1. If date range is null, return all realtime task types. + * 2. If date range is not null, return all historical task types, including the HC entity + * level task type, when resetLatestTaskStateFlag is true (so the latest flag of every task + * can be reset); otherwise return only the historical detector level task types. + * @param dateRange detection date range + * @param resetLatestTaskStateFlag reset latest task state or not + * @return list of AD task types */ - public void maintainRunningRealtimeTasks() { - String[] detectorIds = adTaskCacheManager.getDetectorIdsInRealtimeTaskCache(); - if (detectorIds == null || detectorIds.length == 0) { - return; + protected List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag) { + // AD does not support run once + return getTaskTypes(dateRange, resetLatestTaskStateFlag, false); + } + + @Override + protected BiCheckedFunction getTaskParser() { + return ADTask::parse; + } + + @Override + public void createRunOnceTaskAndCleanupStaleTasks( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ) { + // AD does not support run-once analyses yet + throw new UnsupportedOperationException("AD has no run once yet"); + } + + @Override + public List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce) { + if (dateRange == null) { + return REALTIME_TASK_TYPES; + } else { + if (resetLatestTaskStateFlag) { + // return all task types include HC entity task to make sure we can reset all tasks latest flag + return ALL_HISTORICAL_TASK_TYPES; + } else { + return HISTORICAL_DETECTOR_TASK_TYPES; + } } - for (int i = 0; i < detectorIds.length; i++) { - String detectorId = detectorIds[i]; - RealtimeTaskCache taskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - if (taskCache != null && taskCache.expired()) { - adTaskCacheManager.removeRealtimeTaskCache(detectorId); + } +
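The reset flow added below chains the historical and realtime resets through a single listener, so only the innermost continuation may complete it. A minimal sketch of that contract, with illustrative names and signatures rather than the real ones:

```java
import java.util.List;
import org.opensearch.core.action.ActionListener;

// Sketch: outer reset steps only forward the continuation; the innermost
// step is the one place where the caller's listener is completed.
final class ResetChainSketch {
    void resetAll(List<String> historical, List<String> realtime, Runnable done, ActionListener<Void> listener) {
        // historical reset wraps the realtime reset, which wraps the final callback
        resetHistorical(historical, () -> resetRealtime(realtime, done, listener), listener);
    }

    private void resetHistorical(List<String> tasks, Runnable next, ActionListener<Void> listener) {
        // may call listener.onFailure(e) on error, but never listener.onResponse(...)
        next.run();
    }

    private void resetRealtime(List<String> tasks, Runnable done, ActionListener<Void> listener) {
        done.run();
        // innermost step: this is where the listener is ultimately completed
        listener.onResponse(null);
    }
}
```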
+ /** + * Reset latest config task state. Will reset both historical and realtime tasks. + * [Important!] Make sure the listener is completed inside this method. + * + * @param tasks tasks + * @param function consumer function + * @param transportService transport service + * @param listener action listener + * @param response type of action listener + */ + @Override + protected void resetLatestConfigTaskState( + List tasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ) { + List runningHistoricalTasks = new ArrayList<>(); + List runningRealtimeTasks = new ArrayList<>(); + + for (TimeSeriesTask task : tasks) { + if (!task.isHistoricalEntityTask() && !task.isDone()) { + if (task.isRealTimeTask()) { + runningRealtimeTasks.add(task); + } else if (task.isHistoricalTask()) { + runningHistoricalTasks.add(task); + } } } + + // resetRealtimeConfigTaskState has to be the innermost call because the listener is completed there. + // AD has no run-once and forecasting has no historical analysis, so the run-once and historical + // reset functions only forward the continuation and never complete the listener themselves. + resetHistoricalConfigTaskState( + runningHistoricalTasks, + () -> resetRealtimeConfigTaskState(runningRealtimeTasks, () -> function.accept(tasks), transportService, listener), + transportService, + listener + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java index 84fe0c6fe..df6194353 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java @@ -14,10 +14,10 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADBatchAnomalyResultAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; public static final ADBatchAnomalyResultAction INSTANCE = new ADBatchAnomalyResultAction(); private ADBatchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java index d865ec14c..84a22b261 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java @@ -14,10 +14,10 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK_REMOTE; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADBatchTaskRemoteExecutionAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; public static final ADBatchTaskRemoteExecutionAction INSTANCE = new ADBatchTaskRemoteExecutionAction(); private ADBatchTaskRemoteExecutionAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java index 31f20fa00..d20759f70 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java 
@@ -14,11 +14,11 @@ import static org.opensearch.ad.constant.ADCommonName.CANCEL_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADCancelTaskAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; public static final ADCancelTaskAction INSTANCE = new ADCancelTaskAction(); private ADCancelTaskAction() { diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADEntityProfileAction.java similarity index 54% rename from src/main/java/org/opensearch/ad/transport/EntityProfileAction.java rename to src/main/java/org/opensearch/ad/transport/ADEntityProfileAction.java index c699d9a03..11e6a44a4 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADEntityProfileAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.EntityProfileResponse; -public class EntityProfileAction extends ActionType { +public class ADEntityProfileAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; - public static final EntityProfileAction INSTANCE = new EntityProfileAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; + public static final ADEntityProfileAction INSTANCE = new ADEntityProfileAction(); - private EntityProfileAction() { + private ADEntityProfileAction() { super(NAME, EntityProfileResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/ADEntityProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADEntityProfileTransportAction.java new file mode 100644 index 000000000..5ffde2999 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADEntityProfileTransportAction.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.transport.BaseEntityProfileTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Transport action to get entity profile. 
+ */ +public class ADEntityProfileTransportAction extends + BaseEntityProfileTransportAction { + + @Inject + public ADEntityProfileTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + HashRing hashRing, + ClusterService clusterService, + ADCacheProvider cacheProvider + ) { + super( + actionFilters, + transportService, + settings, + hashRing, + clusterService, + cacheProvider, + ADEntityProfileAction.NAME, + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT + ); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADProfileAction.java similarity index 59% rename from src/main/java/org/opensearch/ad/transport/ProfileAction.java rename to src/main/java/org/opensearch/ad/transport/ADProfileAction.java index 291dd0982..1d51add9e 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADProfileAction.java @@ -12,20 +12,21 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.ProfileResponse; /** * Profile transport action */ -public class ProfileAction extends ActionType { +public class ADProfileAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile"; - public static final ProfileAction INSTANCE = new ProfileAction(); + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detectors/profile"; + public static final ADProfileAction INSTANCE = new ADProfileAction(); /** * Constructor */ - private ProfileAction() { + private ADProfileAction() { super(NAME, ProfileResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/ADProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADProfileTransportAction.java new file mode 100644 index 000000000..af7c40bb4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADProfileTransportAction.java @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.BaseProfileTransportAction; +import org.opensearch.timeseries.transport.ProfileNodeRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * This class contains the logic to extract the stats from the nodes + */ +public class ADProfileTransportAction extends BaseProfileTransportAction { + private ADModelManager modelManager; + private FeatureManager featureManager; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param modelManager model manager object + * @param featureManager feature manager object + * @param cacheProvider cache provider + * @param settings Node settings accessor + */ + @Inject + public ADProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cacheProvider, + Settings settings + ) { + super( + ADProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + cacheProvider, + settings, + AD_MAX_MODEL_SIZE_PER_NODE + ); + this.modelManager = modelManager; + this.featureManager = featureManager; + } + + @Override + protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { + String detectorId = request.getConfigId(); + Set profiles = request.getProfilesToBeRetrieved(); + int shingleSize = -1; + long activeEntity = 0; + long totalUpdates = 0; + Map modelSize = null; + List modelProfiles = null; + int modelCount = 0; + if (request.isModelInPriorityCache()) { + super.nodeOperation(request); + } else { + if (profiles.contains(ProfileName.COORDINATING_NODE) || profiles.contains(ProfileName.SHINGLE_SIZE)) { + shingleSize = featureManager.getShingleSize(detectorId); + } + + if (profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(ProfileName.MODELS)) { + modelSize = modelManager.getModelSize(detectorId); + } + } + + return new ProfileNodeResponse( + clusterService.localNode(), + modelSize, + shingleSize, + activeEntity, + totalUpdates, + modelProfiles, + modelCount + ); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java index 041d543b7..e54a4747e 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java @@ -12,18 +12,19 @@ package 
org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.transport.ResultBulkResponse; import org.opensearch.transport.TransportRequestOptions; -public class ADResultBulkAction extends ActionType { +public class ADResultBulkAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; public static final ADResultBulkAction INSTANCE = new ADResultBulkAction(); private ADResultBulkAction() { - super(NAME, ADResultBulkResponse::new); + super(NAME, ResultBulkResponse::new); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java index f5f361f69..0f8430a25 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java @@ -12,73 +12,19 @@ package org.opensearch.ad.transport; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.ValidateActions; -import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.transport.ResultBulkRequest; -public class ADResultBulkRequest extends ActionRequest implements Writeable { - private final List anomalyResults; - static final String NO_REQUESTS_ADDED_ERR = "no requests added"; +public class ADResultBulkRequest extends ResultBulkRequest { public ADResultBulkRequest() { - anomalyResults = new ArrayList<>(); + super(); } public ADResultBulkRequest(StreamInput in) throws IOException { - super(in); - int size = in.readVInt(); - anomalyResults = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - anomalyResults.add(new ResultWriteRequest(in)); - } - } - - @Override - public ActionRequestValidationException validate() { - ActionRequestValidationException validationException = null; - if (anomalyResults.isEmpty()) { - validationException = ValidateActions.addValidationError(NO_REQUESTS_ADDED_ERR, validationException); - } - return validationException; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeVInt(anomalyResults.size()); - for (ResultWriteRequest result : anomalyResults) { - result.writeTo(out); - } - } - - /** - * - * @return all of the results to send - */ - public List getAnomalyResults() { - return anomalyResults; - } - - /** - * Add result to send - * @param resultWriteRequest The result write request - */ - public void add(ResultWriteRequest resultWriteRequest) { - anomalyResults.add(resultWriteRequest); - } - - /** - * - * @return total index requests - */ - public int numberOfActions() { - return anomalyResults.size(); + super(in, ADResultWriteRequest::new); } } diff --git 
a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java index 03ca7657c..d5442e57c 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java @@ -14,45 +14,31 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; import java.io.IOException; import java.util.List; -import java.util.Random; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.bulk.BulkAction; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.ad.util.BulkUtil; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexingPressure; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.ResultBulkTransportAction; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; -public class ADResultBulkTransportAction extends HandledTransportAction { +public class ADResultBulkTransportAction extends ResultBulkTransportAction { private static final Logger LOG = LogManager.getLogger(ADResultBulkTransportAction.class); - private IndexingPressure indexingPressure; - private final long primaryAndCoordinatingLimits; - private float softLimit; - private float hardLimit; - private String indexName; - private Client client; - private Random random; @Inject public ADResultBulkTransportAction( @@ -63,69 +49,51 @@ public ADResultBulkTransportAction( ClusterService clusterService, Client client ) { - super(ADResultBulkAction.NAME, transportService, actionFilters, ADResultBulkRequest::new, ThreadPool.Names.SAME); - this.indexingPressure = indexingPressure; - this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); - this.softLimit = AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings); - this.hardLimit = AD_INDEX_PRESSURE_HARD_LIMIT.get(settings); - this.indexName = ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; - this.client = client; + super( + ADResultBulkAction.NAME, + transportService, + actionFilters, + indexingPressure, + settings, + client, + AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings), + AD_INDEX_PRESSURE_HARD_LIMIT.get(settings), + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + ADResultBulkRequest::new + ); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); - 
// random seed is 42. Can be any number - this.random = new Random(42); } @Override - protected void doExecute(Task task, ADResultBulkRequest request, ActionListener listener) { - // Concurrent indexing memory limit = 10% of heap - // indexing pressure = indexing bytes / indexing limit - // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index - // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). - long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); - float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; - List results = request.getAnomalyResults(); - - if (results == null || results.size() < 1) { - listener.onResponse(new ADResultBulkResponse()); - } - + protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ADResultBulkRequest request) { BulkRequest bulkRequest = new BulkRequest(); + List results = request.getAnomalyResults(); if (indexingPressurePercent <= softLimit) { - for (ResultWriteRequest resultWriteRequest : results) { - addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getCustomResultIndex()); + for (ADResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getResultIndex()); } } else if (indexingPressurePercent <= hardLimit) { // exceed soft limit (60%) but smaller than hard limit (90%) float acceptProbability = 1 - indexingPressurePercent; - for (ResultWriteRequest resultWriteRequest : results) { + for (ADResultWriteRequest resultWriteRequest : results) { AnomalyResult result = resultWriteRequest.getResult(); if (result.isHighPriority() || random.nextFloat() < acceptProbability) { - addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); } } } else { // if exceeding hard limit, only index non-zero grade or error result - for (ResultWriteRequest resultWriteRequest : results) { + for (ADResultWriteRequest resultWriteRequest : results) { AnomalyResult result = resultWriteRequest.getResult(); if (result.isHighPriority()) { - addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); } } } - if (bulkRequest.numberOfActions() > 0) { - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { - List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); - listener.onResponse(new ADResultBulkResponse(failedRequests)); - }, e -> { - LOG.error("Failed to bulk index AD result", e); - listener.onFailure(e); - })); - } else { - listener.onResponse(new ADResultBulkResponse()); - } + return bulkRequest; } private void addResult(BulkRequest bulkRequest, AnomalyResult result, String resultIndex) { diff --git a/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java new file mode 100644 index 000000000..b5acfbaea --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java @@ -0,0 +1,497 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY; +import static 
org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ADResultProcessor extends + ResultProcessor { + private static final Logger LOG = LogManager.getLogger(ADResultProcessor.class); + + private final ADModelManager adModelManager; + + public ADResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + ADStats timeSeriesStats, + ADTaskManager realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + 
FeatureManager featureManager, + ADModelManager adModelManager + ) { + super( + requestTimeoutSetting, + intervalRatioForRequests, + entityResultAction, + hcRequestCountStat, + settings, + clusterService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + hashRing, + nodeStateManager, + transportService, + timeSeriesStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + transportResultResponseClazz, + featureManager, + AD_MAX_ENTITIES_PER_QUERY, + AD_PAGE_SIZE, + AnalysisType.AD, + false + ); + this.adModelManager = adModelManager; + } + + // For single stream detector + @Override + protected ActionListener onFeatureResponseForSingleStreamConfig( + String adID, + Config config, + ActionListener listener, + String rcfModelId, + DiscoveryNode rcfNode, + long dataStartTime, + long dataEndTime, + String taskId + ) { + return ActionListener.wrap(featureOptional -> { + List featureInResponse = null; + AnomalyDetector detector = (AnomalyDetector) config; + if (featureOptional.getUnprocessedFeatures().isPresent()) { + featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); + } + + if (!featureOptional.getProcessedFeatures().isPresent()) { + + Optional exception = coldStartIfNoCheckPoint(detector); + if (exception.isPresent()) { + listener.onFailure(exception.get()); + return; + } + + if (!featureOptional.getUnprocessedFeatures().isPresent()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. + LOG.debug("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID); + listener + .onResponse( + createResultResponse( + new ArrayList(), + "No data in current detection window", + null, + null, + false, + taskId + ) + ); + + } else { + LOG.debug("Return at least current feature value between {} and {} for {}", dataStartTime, dataEndTime, adID); + listener + .onResponse( + createResultResponse( + featureInResponse, + "No full shingle in current detection window", + null, + null, + false, + taskId + ) + ); + } + return; + } + + final AtomicReference failure = new AtomicReference(); + + LOG.info("Sending RCF request to {} for model {}", rcfNode.getId(), rcfModelId); + + RCFActionListener rcfListener = new RCFActionListener( + rcfModelId, + failure, + rcfNode.getId(), + detector, + listener, + featureInResponse, + adID + ); + + // The threshold for splitting RCF models in single-stream detectors. + // The smallest machine in the Amazon managed service has 1GB heap. + // With the setting, the desired model size there is of 2 MB. + // By default, we can have at most 5 features. Since the default shingle size + // is 8, we have at most 40 dimensions in RCF. In our current RCF setting, + // 30 trees, and bounding box cache ratio 0, 40 dimensions use 449KB. + // Users can increase the number of features to 10 and shingle size to 60, + // 30 trees, bounding box cache ratio 0, 600 dimensions use 1.8 MB. + // Since these sizes are smaller than the threshold 2 MB, we won't split models + // even in the smallest machine. 
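+ // To make the arithmetic above concrete (an added illustration, using only the numbers already given): + // 449 KB / 40 dimensions is roughly 11 KB per dimension, while 1.8 MB / 600 dimensions is roughly 3 KB + // per dimension, so model size grows sublinearly in dimensions and even the largest supported + // single-stream configuration stays under the 2 MB split threshold.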
+ transportService + .sendRequest( + rcfNode, + RCFResultAction.NAME, + new RCFResultRequest(adID, rcfModelId, featureOptional.getProcessedFeatures().get()), + option, + new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) + ); + }, exception -> { handleQueryFailure(exception, listener, adID); }); + } + + // For single stream detector + class RCFActionListener implements ActionListener { + private String modelID; + private AtomicReference failure; + private String rcfNodeID; + private Config detector; + private ActionListener listener; + private List featureInResponse; + private final String adID; + + RCFActionListener( + String modelID, + AtomicReference failure, + String rcfNodeID, + Config detector, + ActionListener listener, + List features, + String adID + ) { + this.modelID = modelID; + this.failure = failure; + this.rcfNodeID = rcfNodeID; + this.detector = detector; + this.listener = listener; + this.featureInResponse = features; + this.adID = adID; + } + + @Override + public void onResponse(RCFResultResponse response) { + try { + nodeStateManager.resetBackpressureCounter(rcfNodeID, adID); + if (response != null) { + listener + .onResponse( + new AnomalyResultResponse( + response.getAnomalyGrade(), + response.getConfidence(), + response.getRCFScore(), + featureInResponse, + null, + response.getTotalUpdates(), + detector.getIntervalInMinutes(), + false, + response.getRelativeIndex(), + response.getAttribution(), + response.getPastValues(), + response.getExpectedValuesList(), + response.getLikelihoodOfValues(), + response.getThreshold(), + null + ) + ); + } else { + LOG.warn(ResultProcessor.NULL_RESPONSE + " {} for {}", modelID, rcfNodeID); + listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); + } + } catch (Exception ex) { + LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); + ResultProcessor.handleExecuteException(ex, listener, adID); + } + } + + @Override + public void onFailure(Exception e) { + try { + handlePredictionFailure(e, adID, rcfNodeID, failure); + Exception exception = coldStartIfNoModel(failure, detector); + if (exception != null) { + listener.onFailure(exception); + } else { + listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception")); + } + } catch (Exception ex) { + LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); + ResultProcessor.handleExecuteException(ex, listener, adID); + } + } + } + + /** + * Verify failure of rcf or threshold models. If there is no model, trigger cold + * start. If there is an exception for the previous cold start of this detector, + * throw exception to the caller. + * + * @param failure object that may contain exceptions thrown + * @param detector detector object + * @return exception if AD job execution gets resource not found exception + * @throws Exception when the input failure is not a ResourceNotFoundException. + * List of exceptions we can throw + * 1. Exception from cold start: + * 1). InternalFailure due to + * a. OpenSearchTimeoutException thrown by putModelCheckpoint during cold start + * 2). EndRunException with endNow equal to false + * a. training data not available + * b. cold start cannot succeed + * c. invalid training data + * 3) EndRunException with endNow equal to true + * a. invalid search query + * 2. LimitExceededException from one of RCF model node when the total size of the models + * is more than X% of heap memory. + * 3. 
InternalFailure wrapping OpenSearchTimeoutException inside caused by + * RCF/Threshold model node failing to get checkpoint to restore model before timeout. + */ + private Exception coldStartIfNoModel(AtomicReference failure, Config detector) throws Exception { + Exception exp = failure.get(); + if (exp == null) { + return null; + } + + // return exceptions like LimitExceededException to caller + if (!(exp instanceof ResourceNotFoundException)) { + return exp; + } + + // fetch previous cold start exception + String adID = detector.getId(); + final Optional previousException = nodeStateManager.fetchExceptionAndClear(adID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); + if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { + return exception; + } + } + LOG.info("Trigger cold start for {}", detector.getId()); + // only used in single-stream anomaly detector thus type cast + coldStart((AnomalyDetector) detector); + return previousException.orElse(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); + } + + // only used for single-stream anomaly detector + private void coldStart(AnomalyDetector detector) { + String detectorId = detector.getId(); + + // If last cold start is not finished, we don't trigger another one + if (nodeStateManager.isColdStartRunning(detectorId)) { + return; + } + + final Releasable coldStartFinishingCallback = nodeStateManager.markColdStartRunning(detectorId); + + ActionListener> listener = ActionListener.wrap(trainingData -> { + if (trainingData.isPresent()) { + double[][] dataPoints = trainingData.get(); + + ActionListener trainModelListener = ActionListener + .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> { + if (exception instanceof TimeSeriesException) { + // e.g., partitioned model exceeds memory limit + nodeStateManager.setException(detectorId, exception); + } else if (exception instanceof IllegalArgumentException) { + // IllegalArgumentException due to invalid training data + nodeStateManager + .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false)); + } else if (exception instanceof OpenSearchTimeoutException) { + nodeStateManager + .setException( + detectorId, + new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) + ); + } else { + nodeStateManager + .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false)); + } + }); + + adModelManager + .trainModel( + detector, + dataPoints, + new ThreadedActionListener<>( + LOG, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + trainModelListener, + false + ) + ); + } else { + nodeStateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); + } + }, exception -> { + if (exception instanceof OpenSearchTimeoutException) { + nodeStateManager + .setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception)); + } else if (exception instanceof TimeSeriesException) { + // e.g., Invalid search query + nodeStateManager.setException(detectorId, exception); + } else { + nodeStateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); + } + }); + + final ActionListener> listenerWithReleaseCallback = ActionListener + .runAfter(listener, 
coldStartFinishingCallback::close); + + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute( + () -> featureManager + .getColdStartData( + detector, + new ThreadedActionListener<>( + LOG, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + listenerWithReleaseCallback, + false + ) + ) + ); + } + + /** + * Check whether a checkpoint exists for a detector. If not, and the previous + * run did not end with an EndRunException whose endNow flag is true, trigger cold start. + * @param detector detector object + * @return previous cold start exception + */ + private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { + String detectorId = detector.getId(); + + Optional previousException = nodeStateManager.fetchExceptionAndClear(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of {}:", detectorId), exception); + if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { + return previousException; + } + } + + nodeStateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> { + if (!checkpointExists) { + LOG.info("Trigger cold start for {}", detectorId); + coldStart(detector); + } + }, exception -> { + Throwable cause = ExceptionsHelper.unwrapCause(exception); + if (cause instanceof IndexNotFoundException) { + LOG.info("Trigger cold start for {}", detectorId); + coldStart(detector); + } else { + String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId); + LOG.error(errorMsg, exception); + nodeStateManager.setException(detectorId, new TimeSeriesException(errorMsg, exception)); + } + })); + + return previousException; + } + + @Override + protected void findException(Throwable cause, String adID, AtomicReference failure, String nodeId) { + if (cause == null) { + LOG.error(new ParameterizedMessage("Null input exception")); + return; + } + + Exception causeException = (Exception) cause; + + if (causeException instanceof IndexNotFoundException && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME)) { + // checkpoint index does not exist + // ResourceNotFoundException will trigger cold start later + failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); + } + super.findException(cause, adID, failure, nodeId); + } + + @Override + protected AnomalyResultResponse createResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long configInterval, + Boolean isHC, + String taskId + ) { + return new AnomalyResultResponse(features, error, rcfTotalUpdates, configInterval, isHC, taskId); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java index f6f39ab85..d6fa4c64b 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java @@ -12,22 +12,23 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StatsNodesResponse; /** * ADStatsNodesAction class */ -public class ADStatsNodesAction extends ActionType { +public class ADStatsNodesAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; public static final ADStatsNodesAction INSTANCE = new ADStatsNodesAction(); /** * Constructor */ private ADStatsNodesAction() { - super(NAME, ADStatsNodesResponse::new); + super(NAME, StatsNodesResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java index 17a81da0a..bfaacbef1 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java @@ -11,32 +11,27 @@ package org.opensearch.ad.transport; -import java.io.IOException; -import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; -import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.InternalStatNames; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.stats.InternalStatNames; +import org.opensearch.timeseries.transport.BaseStatsNodesTransportAction; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsRequest; import org.opensearch.transport.TransportService; /** * ADStatsNodesTransportAction contains the logic to extract the stats from the nodes */ -public class ADStatsNodesTransportAction extends - TransportNodesAction { +public class ADStatsNodesTransportAction extends BaseStatsNodesTransportAction { - private ADStats adStats; private final JvmService jvmService; private final ADTaskManager adTaskManager; @@ -47,7 +42,7 @@ public class ADStatsNodesTransportAction extends * @param clusterService ClusterService * @param transportService TransportService * @param actionFilters Action Filters - * @param adStats ADStats object + * @param adStats TimeSeriesStats object * @param jvmService ES JVM Service * @param adTaskManager AD task manager */ @@ -61,48 +56,14 @@ public ADStatsNodesTransportAction( JvmService jvmService, ADTaskManager adTaskManager ) { - super( - ADStatsNodesAction.NAME, - threadPool, - clusterService, - transportService, - actionFilters, - ADStatsRequest::new, - ADStatsNodeRequest::new, - ThreadPool.Names.MANAGEMENT, - ADStatsNodeResponse.class - ); - this.adStats = adStats; + super(threadPool, clusterService, transportService, actionFilters, adStats, ADStatsNodesAction.NAME); this.jvmService = jvmService; this.adTaskManager = adTaskManager; } @Override - protected ADStatsNodesResponse newResponse( - ADStatsRequest request, - List responses, - List failures - ) { - return new ADStatsNodesResponse(clusterService.getClusterName(), responses, failures); - } - - @Override - protected ADStatsNodeRequest newNodeRequest(ADStatsRequest request) { - return new ADStatsNodeRequest(request); - } - - @Override - protected ADStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { - return new ADStatsNodeResponse(in); - } - - @Override - protected ADStatsNodeResponse 
nodeOperation(ADStatsNodeRequest request) { - return createADStatsNodeResponse(request.getADStatsRequest()); - } - - private ADStatsNodeResponse createADStatsNodeResponse(ADStatsRequest adStatsRequest) { - Map statValues = new HashMap<>(); + protected StatsNodeResponse createADStatsNodeResponse(StatsRequest adStatsRequest) { + Map statValues = super.createADStatsNodeResponse(adStatsRequest).getStatsMap(); Set statsToBeRetrieved = adStatsRequest.getStatsToBeRetrieved(); if (statsToBeRetrieved.contains(InternalStatNames.JVM_HEAP_USAGE.getName())) { @@ -120,12 +81,6 @@ private ADStatsNodeResponse createADStatsNodeResponse(ADStatsRequest adStatsRequ statValues.put(InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName(), assignedBatchTaskSlot); } - for (String statName : adStats.getNodeStats().keySet()) { - if (statsToBeRetrieved.contains(statName)) { - statValues.put(statName, adStats.getStats().get(statName).getValue()); - } - } - - return new ADStatsNodeResponse(clusterService.localNode(), statValues); + return new StatsNodeResponse(clusterService.localNode(), statValues); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java index f2b198d1c..f66d9e1ec 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java @@ -14,11 +14,11 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADTaskProfileAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK; public static final ADTaskProfileAction INSTANCE = new ADTaskProfileAction(); private ADTaskProfileAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java index 6902d6de8..4bfbf7ca3 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java @@ -18,13 +18,13 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.transport.TransportService; public class ADTaskProfileTransportAction extends @@ -79,7 +79,7 @@ protected ADTaskProfileNodeResponse newNodeResponse(StreamInput in) throws IOExc @Override protected ADTaskProfileNodeResponse nodeOperation(ADTaskProfileNodeRequest request) { String remoteNodeId = request.getParentTask().getNodeId(); - Version remoteAdVersion = hashRing.getAdVersion(remoteNodeId); + Version remoteAdVersion = hashRing.getVersion(remoteNodeId); ADTaskProfile adTaskProfile = adTaskManager.getLocalADTaskProfilesByDetectorId(request.getId()); 
return new ADTaskProfileNodeResponse(clusterService.localNode(), adTaskProfile, remoteAdVersion); } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java index b11283181..b03180b70 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.JobResponse; -public class AnomalyDetectorJobAction extends ActionType { +public class AnomalyDetectorJobAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement"; public static final AnomalyDetectorJobAction INSTANCE = new AnomalyDetectorJobAction(); private AnomalyDetectorJobAction() { - super(NAME, AnomalyDetectorJobResponse::new); + super(NAME, JobResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java deleted file mode 100644 index 3a62315a6..000000000 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport; - -import java.io.IOException; - -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.timeseries.model.DateRange; - -public class AnomalyDetectorJobRequest extends ActionRequest { - - private String detectorID; - private DateRange detectionDateRange; - private boolean historical; - private long seqNo; - private long primaryTerm; - private String rawPath; - - public AnomalyDetectorJobRequest(StreamInput in) throws IOException { - super(in); - detectorID = in.readString(); - seqNo = in.readLong(); - primaryTerm = in.readLong(); - rawPath = in.readString(); - if (in.readBoolean()) { - detectionDateRange = new DateRange(in); - } - historical = in.readBoolean(); - } - - public AnomalyDetectorJobRequest(String detectorID, long seqNo, long primaryTerm, String rawPath) { - this(detectorID, null, false, seqNo, primaryTerm, rawPath); - } - - /** - * Constructor function. - * - * The detectionDateRange and historical boolean can be passed in individually. - * The historical flag is for stopping detector, the detectionDateRange is for - * starting detector. It's ok if historical is true but detectionDateRange is - * null. 
- * - * @param detectorID detector identifier - * @param detectionDateRange detection date range - * @param historical historical analysis or not - * @param seqNo seq no - * @param primaryTerm primary term - * @param rawPath raw request path - */ - public AnomalyDetectorJobRequest( - String detectorID, - DateRange detectionDateRange, - boolean historical, - long seqNo, - long primaryTerm, - String rawPath - ) { - super(); - this.detectorID = detectorID; - this.detectionDateRange = detectionDateRange; - this.historical = historical; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.rawPath = rawPath; - } - - public String getDetectorID() { - return detectorID; - } - - public DateRange getDetectionDateRange() { - return detectionDateRange; - } - - public long getSeqNo() { - return seqNo; - } - - public long getPrimaryTerm() { - return primaryTerm; - } - - public String getRawPath() { - return rawPath; - } - - public boolean isHistorical() { - return historical; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(detectorID); - out.writeLong(seqNo); - out.writeLong(primaryTerm); - out.writeString(rawPath); - if (detectionDateRange != null) { - out.writeBoolean(true); - detectionDateRange.writeTo(out); - } else { - out.writeBoolean(false); - } - out.writeBoolean(historical); - } - - @Override - public ActionRequestValidationException validate() { - return null; - } -} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java index 5a81c43ae..358c9a062 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java @@ -15,46 +15,28 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_STOP_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import 
org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.model.DateRange; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.transport.BaseJobTransportAction; import org.opensearch.transport.TransportService; -public class AnomalyDetectorJobTransportAction extends HandledTransportAction { - private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class); - - private final Client client; - private final ClusterService clusterService; - private final Settings settings; - private final ADIndexManagement anomalyDetectionIndices; - private final NamedXContentRegistry xContentRegistry; - private volatile Boolean filterByEnabled; - private final ADTaskManager adTaskManager; - private final TransportService transportService; - private final ExecuteADResultResponseRecorder recorder; - +public class AnomalyDetectorJobTransportAction extends + BaseJobTransportAction { @Inject public AnomalyDetectorJobTransportAction( TransportService transportService, @@ -62,95 +44,23 @@ public AnomalyDetectorJobTransportAction( Client client, ClusterService clusterService, Settings settings, - ADIndexManagement anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager, - ExecuteADResultResponseRecorder recorder - ) { - super(AnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); - this.transportService = transportService; - this.client = client; - this.clusterService = clusterService; - this.settings = settings; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.xContentRegistry = xContentRegistry; - this.adTaskManager = adTaskManager; - filterByEnabled = AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.recorder = recorder; - } - - @Override - protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener actionListener) { - String detectorId = request.getDetectorID(); - DateRange detectionDateRange = request.getDetectionDateRange(); - boolean historical = request.isHistorical(); - long seqNo = request.getSeqNo(); - long primaryTerm = request.getPrimaryTerm(); - String rawPath = request.getRawPath(); - TimeValue requestTimeout = AD_REQUEST_TIMEOUT.get(settings); - String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? FAIL_TO_START_DETECTOR : FAIL_TO_STOP_DETECTOR; - ActionListener listener = wrapRestActionListener(actionListener, errorMessage); - - // By the time request reaches here, the user permissions are validated by Security plugin. 
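    // The removed code below relies on a privilege-switch idiom that the consolidated
    // BaseJobTransportAction presumably keeps: capture the calling user first, then
    // stash the thread context so internal index operations run as the plugin. A
    // minimal sketch, with `runPrivileged` as an illustrative placeholder:
    //
    //     User user = getUserContext(client); // read caller identity before stashing
    //     try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
    //         runPrivileged(user);  // executes without the caller's security headers
    //     }                         // caller's context is restored when the stored context closes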
- User user = getUserContext(client); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute( - user, - detectorId, - filterByEnabled, - listener, - (anomalyDetector) -> executeDetector( - listener, - detectorId, - detectionDateRange, - historical, - seqNo, - primaryTerm, - rawPath, - requestTimeout, - user, - context - ), - client, - clusterService, - xContentRegistry, - AnomalyDetector.class - ); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } - } - - private void executeDetector( - ActionListener listener, - String detectorId, - DateRange detectionDateRange, - boolean historical, - long seqNo, - long primaryTerm, - String rawPath, - TimeValue requestTimeout, - User user, - ThreadContext.StoredContext context + ADIndexJobActionHandler adIndexJobActionHandler ) { - IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + super( + transportService, + actionFilters, client, - anomalyDetectionIndices, - detectorId, - seqNo, - primaryTerm, - requestTimeout, + clusterService, + settings, xContentRegistry, - transportService, - adTaskManager, - recorder + AD_FILTER_BY_BACKEND_ROLES, + AnomalyDetectorJobAction.NAME, + AD_REQUEST_TIMEOUT, + FAIL_TO_START_DETECTOR, + FAIL_TO_STOP_DETECTOR, + AnomalyDetector.class, + adIndexJobActionHandler ); - if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { - adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); - } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { - adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener); - } } } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java index d61bd5822..36c8a2c9d 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class AnomalyResultAction extends ActionType { - // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/run"; + // External Action which used for public facing RestAPIs or actions we need to assume cx's role. 
+ public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/run"; public static final AnomalyResultAction INSTANCE = new AnomalyResultAction(); private AnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java index e6f788aeb..397271da0 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java @@ -26,56 +26,24 @@ import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.ResultRequest; -public class AnomalyResultRequest extends ActionRequest implements ToXContentObject { - private String adID; - // time range start and end. Unit: epoch milliseconds - private long start; - private long end; - +public class AnomalyResultRequest extends ResultRequest { public AnomalyResultRequest(StreamInput in) throws IOException { super(in); - adID = in.readString(); - start = in.readLong(); - end = in.readLong(); } public AnomalyResultRequest(String adID, long start, long end) { - super(); - this.adID = adID; - this.start = start; - this.end = end; - } - - public long getStart() { - return start; - } - - public long getEnd() { - return end; - } - - public String getAdID() { - return adID; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(adID); - out.writeLong(start); - out.writeLong(end); + super(adID, start, end); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { + if (Strings.isEmpty(configId)) { validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); } if (start <= 0 || end <= 0 || start > end) { @@ -90,7 +58,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(ADCommonName.ID_JSON_KEY, configId); builder.field(CommonName.START_JSON_KEY, start); builder.field(CommonName.END_JSON_KEY, end); builder.endObject(); diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java index 67113d3af..8708cb92a 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java @@ -17,6 +17,7 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; @@ -27,11 +28,12 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; import 
org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.transport.ResultResponse; -public class AnomalyResultResponse extends ActionResponse implements ToXContentObject { +public class AnomalyResultResponse extends ResultResponse { public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; public static final String CONFIDENCE_JSON_KEY = "confidence"; public static final String ANOMALY_SCORE_JSON_KEY = "anomalyScore"; @@ -49,18 +51,13 @@ public class AnomalyResultResponse extends ActionResponse implements ToXContentO private Double anomalyGrade; private Double confidence; - private Double anomalyScore; - private String error; - private List features; - private Long rcfTotalUpdates; - private Long detectorIntervalInMinutes; - private Boolean isHCDetector; private Integer relativeIndex; private double[] relevantAttribution; private double[] pastValues; private double[][] expectedValuesList; private double[] likelihoodOfValues; private Double threshold; + protected Double anomalyScore; // used when returning an error/exception or empty result public AnomalyResultResponse( @@ -68,7 +65,8 @@ public AnomalyResultResponse( String error, Long rcfTotalUpdates, Long detectorIntervalInMinutes, - Boolean isHCDetector + Boolean isHCDetector, + String taskId ) { this( Double.NaN, @@ -84,7 +82,8 @@ public AnomalyResultResponse( null, null, null, - Double.NaN + Double.NaN, + taskId ); } @@ -102,16 +101,13 @@ public AnomalyResultResponse( double[] pastValues, double[][] expectedValuesList, double[] likelihoodOfValues, - Double threshold + Double threshold, + String taskId ) { + super(features, error, rcfTotalUpdates, detectorIntervalInMinutes, isHCDetector, taskId); this.anomalyGrade = anomalyGrade; this.confidence = confidence; this.anomalyScore = anomalyScore; - this.features = features; - this.error = error; - this.rcfTotalUpdates = rcfTotalUpdates; - this.detectorIntervalInMinutes = detectorIntervalInMinutes; - this.isHCDetector = isHCDetector; this.relativeIndex = relativeIndex; this.relevantAttribution = currentTimeAttribution; this.pastValues = pastValues; @@ -134,8 +130,8 @@ public AnomalyResultResponse(StreamInput in) throws IOException { // new field added since AD 1.1 // Only send AnomalyResultRequest to local node, no need to change this part for BWC rcfTotalUpdates = in.readOptionalLong(); - detectorIntervalInMinutes = in.readOptionalLong(); - isHCDetector = in.readOptionalBoolean(); + configIntervalInMinutes = in.readOptionalLong(); + isHC = in.readOptionalBoolean(); this.relativeIndex = in.readOptionalInt(); @@ -171,16 +167,13 @@ public AnomalyResultResponse(StreamInput in) throws IOException { } this.threshold = in.readOptionalDouble(); + this.taskId = in.readOptionalString(); } public double getAnomalyGrade() { return anomalyGrade; } - public List getFeatures() { - return features; - } - public double getConfidence() { return confidence; } @@ -189,22 +182,6 @@ public double getAnomalyScore() { return anomalyScore; } - public String getError() { - return error; - } - - public Long getRcfTotalUpdates() { - return rcfTotalUpdates; - } - - public Long getIntervalInMinutes() { - return detectorIntervalInMinutes; - } - - public Boolean isHCDetector() { - return isHCDetector; - } - public Integer getRelativeIndex() { return relativeIndex; } @@ -240,8 +217,8 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalString(error); 
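        // Serialization note: every writeOptional* in this method must be mirrored, in
        // the same order, by a readOptional* in the StreamInput constructor above, e.g.
        //
        //     out.writeOptionalLong(configIntervalInMinutes);    // writer side
        //     configIntervalInMinutes = in.readOptionalLong();   // reader side, same position
        //
        // As the reader's comment notes, this response only travels to the local node,
        // which is why appending new optional fields (like taskId below) is safe.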
         out.writeOptionalLong(rcfTotalUpdates);
-        out.writeOptionalLong(detectorIntervalInMinutes);
-        out.writeOptionalBoolean(isHCDetector);
+        out.writeOptionalLong(configIntervalInMinutes);
+        out.writeOptionalBoolean(isHC);
 
         out.writeOptionalInt(relativeIndex);
@@ -280,6 +257,7 @@ public void writeTo(StreamOutput out) throws IOException {
         }
 
         out.writeOptionalDouble(threshold);
+        out.writeOptionalString(taskId);
     }
 
     @Override
@@ -295,13 +273,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         }
         builder.endArray();
         builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates);
-        builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, detectorIntervalInMinutes);
+        builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, configIntervalInMinutes);
         builder.field(RELATIVE_INDEX_FIELD_JSON_KEY, relativeIndex);
         builder.field(RELEVANT_ATTRIBUTION_FIELD_JSON_KEY, relevantAttribution);
         builder.field(PAST_VALUES_FIELD_JSON_KEY, pastValues);
         builder.field(EXPECTED_VAL_LIST_FIELD_JSON_KEY, expectedValuesList);
         builder.field(LIKELIHOOD_FIELD_JSON_KEY, likelihoodOfValues);
         builder.field(THRESHOLD_FIELD_JSON_KEY, threshold);
+        builder.field(CommonName.TASK_ID_FIELD, taskId);
         builder.endObject();
         return builder;
     }
@@ -325,7 +304,7 @@ public static AnomalyResultResponse fromActionResponse(final ActionResponse acti
      *
      * Convert AnomalyResultResponse to AnomalyResult
      *
-     * @param detectorId Detector Id
+     * @param configId Detector Id
      * @param dataStartInstant data start time
      * @param dataEndInstant data end time
      * @param executionStartInstant execution start time
@@ -335,8 +314,9 @@ public static AnomalyResultResponse fromActionResponse(final ActionResponse acti
      * @param error Error
      * @return converted AnomalyResult
      */
-    public AnomalyResult toAnomalyResult(
-        String detectorId,
+    @Override
+    public List<AnomalyResult> toIndexableResults(
+        String configId,
         Instant dataStartInstant,
         Instant dataEndInstant,
         Instant executionStartInstant,
@@ -347,30 +327,43 @@
     ) {
         // Detector interval in milliseconds
         long detectorIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis();
-        return AnomalyResult
-            .fromRawTRCFResult(
-                detectorId,
-                detectorIntervalMilli,
-                null, // real time results have no task id
-                anomalyScore,
-                anomalyGrade,
-                confidence,
-                features,
-                dataStartInstant,
-                dataEndInstant,
-                executionStartInstant,
-                executionEndInstant,
-                error,
-                Optional.empty(),
-                user,
-                schemaVersion,
-                null, // single-stream real-time has no model id
-                relevantAttribution,
-                relativeIndex,
-                pastValues,
-                expectedValuesList,
-                likelihoodOfValues,
-                threshold
+        return Collections
+            .singletonList(
+                AnomalyResult
+                    .fromRawTRCFResult(
+                        configId,
+                        detectorIntervalMilli,
+                        taskId, // null for real-time results; historical analysis passes its task id
+                        anomalyScore,
+                        anomalyGrade,
+                        confidence,
+                        features,
+                        dataStartInstant,
+                        dataEndInstant,
+                        executionStartInstant,
+                        executionEndInstant,
+                        error,
+                        Optional.empty(),
+                        user,
+                        schemaVersion,
+                        null, // single-stream real-time has no model id
+                        relevantAttribution,
+                        relativeIndex,
+                        pastValues,
+                        expectedValuesList,
+                        likelihoodOfValues,
+                        threshold
+                    )
             );
     }
+
+    @Override
+    public boolean shouldSave() {
+        // Skip writing to the result index when it is not necessary.
+        // For a single-stream analysis, the result is not useful if the error is null
+        // and the rcf score (and thus anomaly grade/confidence) is null.
+        // For an HC analysis, we don't need to save at the detector level.
+        // We return 0 or Double.NaN rcf score if there is no error.
+ return super.shouldSave() || anomalyScore > 0; + } } diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index 084db7f42..3c3d9ca17 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -11,139 +11,59 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; -import static org.opensearch.timeseries.constant.CommonMessages.INVALID_SEARCH_QUERY_MSG; - -import java.net.ConnectException; -import java.util.ArrayList; import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ExceptionsHelper; -import org.opensearch.OpenSearchTimeoutException; -import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.ActionRequest; -import org.opensearch.action.search.SearchPhaseExecutionException; -import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.IndicesOptions; -import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.CompositeRetriever; -import org.opensearch.ad.feature.CompositeRetriever.PageIterator; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SinglePointFeatures; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; -import org.opensearch.cluster.ClusterState; -import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.lease.Releasable; import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.NetworkExceptionHelper; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.node.NodeClosedException; import 
org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; -import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.LimitExceededException; -import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; -import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.FeatureData; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.transport.ResultProcessor; import org.opensearch.timeseries.util.SecurityClientUtil; -import org.opensearch.transport.ActionNotFoundTransportException; -import org.opensearch.transport.ConnectTransportException; -import org.opensearch.transport.NodeNotConnectedException; -import org.opensearch.transport.ReceiveTimeoutTransportException; -import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; public class AnomalyResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(AnomalyResultTransportAction.class); - static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; - static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. Mute node"; - static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; - static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; - static final String NULL_RESPONSE = "Received null response from"; - - static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; - static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; - - private final TransportService transportService; - private final NodeStateManager stateManager; - private final FeatureManager featureManager; - private final ModelManager modelManager; - private final HashRing hashRing; - private final TransportRequestOptions option; - private final ClusterService clusterService; - private final IndexNameExpressionResolver indexNameExpressionResolver; - private final ADStats adStats; - private final CircuitBreakerService adCircuitBreakerService; - private final ThreadPool threadPool; + private ADResultProcessor resultProcessor; private final Client client; - private final SecurityClientUtil clientUtil; - private final ADTaskManager adTaskManager; - + private CircuitBreakerService adCircuitBreakerService; // Cache HC detector id. This is used to count HC failure stats. 
We can tell a detector // is HC or not by checking if detector id exists in this field or not. Will add // detector id to this field when start to run realtime detection and remove detector // id once realtime detection done. private final Set hcDetectors; - private NamedXContentRegistry xContentRegistry; - private Settings settings; - // within an interval, how many percents are used to process requests. - // 1.0 means we use all of the detection interval to process requests. - // to ensure we don't block next interval, it is better to set it less than 1.0. - private final float intervalRatioForRequest; - private int maxEntitiesPerInterval; - private int pageSize; + private final ADStats adStats; + private final NodeStateManager nodeStateManager; @Inject public AnomalyResultTransportAction( @@ -152,9 +72,9 @@ public AnomalyResultTransportAction( Settings settings, Client client, SecurityClientUtil clientUtil, - NodeStateManager manager, + NodeStateManager nodeStateManager, FeatureManager featureManager, - ModelManager modelManager, + ADModelManager modelManager, HashRing hashRing, ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver, @@ -162,37 +82,35 @@ public AnomalyResultTransportAction( ADStats adStats, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager + ADTaskManager realTimeTaskManager ) { super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new); - this.transportService = transportService; - this.settings = settings; + this.resultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + adStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + AnomalyResultResponse.class, + featureManager, + modelManager + ); this.client = client; - this.clientUtil = clientUtil; - this.stateManager = manager; - this.featureManager = featureManager; - this.modelManager = modelManager; - this.hashRing = hashRing; - this.option = TransportRequestOptions - .builder() - .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings)) - .build(); - this.clusterService = clusterService; - this.indexNameExpressionResolver = indexNameExpressionResolver; this.adCircuitBreakerService = adCircuitBreakerService; - this.adStats = adStats; - this.threadPool = threadPool; this.hcDetectors = new HashSet<>(); - this.xContentRegistry = xContentRegistry; - this.intervalRatioForRequest = TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS; - - this.maxEntitiesPerInterval = AD_MAX_ENTITIES_PER_QUERY.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerInterval = it); - - this.pageSize = AD_PAGE_SIZE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_PAGE_SIZE, it -> pageSize = it); - this.adTaskManager = adTaskManager; + this.adStats = adStats; + this.nodeStateManager = nodeStateManager; } /** @@ -249,7 +167,7 @@ public AnomalyResultTransportAction( protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { AnomalyResultRequest request = 
AnomalyResultRequest.fromActionRequest(actionRequest); - String adID = request.getAdID(); + String adID = request.getConfigId(); ActionListener original = listener; listener = ActionListener.wrap(r -> { hcDetectors.remove(adID); @@ -278,864 +196,14 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< return; } try { - stateManager.getConfig(adID, AnalysisType.AD, onGetDetector(listener, adID, request)); + nodeStateManager + .getConfig(adID, AnalysisType.AD, resultProcessor.onGetConfig(listener, adID, request, Optional.of(hcDetectors))); } catch (Exception ex) { - handleExecuteException(ex, listener, adID); + ResultProcessor.handleExecuteException(ex, listener, adID); } } catch (Exception e) { LOG.error(e); listener.onFailure(e); } } - - /** - * didn't use ActionListener.wrap so that I can - * 1) use this to refer to the listener inside the listener - * 2) pass parameters using constructors - * - */ - class PageListener implements ActionListener { - private PageIterator pageIterator; - private String detectorId; - private long dataStartTime; - private long dataEndTime; - - PageListener(PageIterator pageIterator, String detectorId, long dataStartTime, long dataEndTime) { - this.pageIterator = pageIterator; - this.detectorId = detectorId; - this.dataStartTime = dataStartTime; - this.dataEndTime = dataEndTime; - } - - @Override - public void onResponse(CompositeRetriever.Page entityFeatures) { - if (pageIterator.hasNext()) { - pageIterator.next(this); - } - if (entityFeatures != null && false == entityFeatures.isEmpty()) { - // wrap expensive operation inside ad threadpool - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { - try { - - Set>> node2Entities = entityFeatures - .getResults() - .entrySet() - .stream() - .filter(e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).isPresent()) - .collect( - Collectors - .groupingBy( - // from entity name to its node - e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).get(), - Collectors.toMap(Entry::getKey, Entry::getValue) - ) - ) - .entrySet(); - - Iterator>> iterator = node2Entities.iterator(); - - while (iterator.hasNext()) { - Entry> entry = iterator.next(); - DiscoveryNode modelNode = entry.getKey(); - if (modelNode == null) { - iterator.remove(); - continue; - } - String modelNodeId = modelNode.getId(); - if (stateManager.isMuted(modelNodeId, detectorId)) { - LOG - .info( - String - .format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", modelNodeId, detectorId) - ); - iterator.remove(); - } - } - - final AtomicReference failure = new AtomicReference<>(); - node2Entities.stream().forEach(nodeEntity -> { - DiscoveryNode node = nodeEntity.getKey(); - transportService - .sendRequest( - node, - EntityResultAction.NAME, - new EntityResultRequest(detectorId, nodeEntity.getValue(), dataStartTime, dataEndTime), - option, - new ActionListenerResponseHandler<>( - new EntityResultListener(node.getId(), detectorId, failure), - AcknowledgedResponse::new, - ThreadPool.Names.SAME - ) - ); - }); - - } catch (Exception e) { - LOG.error("Unexpected exception", e); - handleException(e); - } - }); - } - } - - @Override - public void onFailure(Exception e) { - LOG.error("Unexpetected exception", e); - handleException(e); - } - - private void handleException(Exception e) { - Exception convertedException = convertedQueryFailureException(e, detectorId); - if (false == (convertedException instanceof TimeSeriesException)) 
{ - Throwable cause = ExceptionsHelper.unwrapCause(convertedException); - convertedException = new InternalFailure(detectorId, cause); - } - stateManager.setException(detectorId, convertedException); - } - } - - private ActionListener> onGetDetector( - ActionListener listener, - String adID, - AnomalyResultRequest request - ) { - return ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new EndRunException(adID, "AnomalyDetector is not available.", true)); - return; - } - - AnomalyDetector anomalyDetector = (AnomalyDetector) detectorOptional.get(); - if (anomalyDetector.isHighCardinality()) { - hcDetectors.add(adID); - adStats.getStat(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName()).increment(); - } - - long delayMillis = Optional - .ofNullable((IntervalTimeConfiguration) anomalyDetector.getWindowDelay()) - .map(t -> t.toDuration().toMillis()) - .orElse(0L); - long dataStartTime = request.getStart() - delayMillis; - long dataEndTime = request.getEnd() - delayMillis; - - adTaskManager - .initRealtimeTaskCacheAndCleanupStaleCache( - adID, - anomalyDetector, - transportService, - ActionListener - .runAfter( - initRealtimeTaskCacheListener(adID), - () -> executeAnomalyDetection(listener, adID, request, anomalyDetector, dataStartTime, dataEndTime) - ) - ); - }, exception -> handleExecuteException(exception, listener, adID)); - } - - private ActionListener initRealtimeTaskCacheListener(String detectorId) { - return ActionListener.wrap(r -> { - if (r) { - LOG.debug("Realtime task cache initied for detector {}", detectorId); - } - }, e -> LOG.error("Failed to init realtime task cache for " + detectorId, e)); - } - - private void executeAnomalyDetection( - ActionListener listener, - String adID, - AnomalyResultRequest request, - AnomalyDetector anomalyDetector, - long dataStartTime, - long dataEndTime - ) { - // HC logic starts here - if (anomalyDetector.isHighCardinality()) { - Optional previousException = stateManager.fetchExceptionAndClear(adID); - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error(new ParameterizedMessage("Previous exception of [{}]", adID), exception); - if (exception instanceof EndRunException) { - EndRunException endRunException = (EndRunException) exception; - if (endRunException.isEndNow()) { - listener.onFailure(exception); - return; - } - } - } - - // assume request are in epoch milliseconds - long nextDetectionStartTime = request.getEnd() + (long) (anomalyDetector.getIntervalInMilliseconds() * intervalRatioForRequest); - - CompositeRetriever compositeRetriever = new CompositeRetriever( - dataStartTime, - dataEndTime, - anomalyDetector, - xContentRegistry, - client, - clientUtil, - nextDetectionStartTime, - settings, - maxEntitiesPerInterval, - pageSize, - indexNameExpressionResolver, - clusterService - ); - - PageIterator pageIterator = null; - - try { - pageIterator = compositeRetriever.iterator(); - } catch (Exception e) { - listener.onFailure(new EndRunException(anomalyDetector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); - return; - } - - PageListener getEntityFeatureslistener = new PageListener(pageIterator, adID, dataStartTime, dataEndTime); - if (pageIterator.hasNext()) { - pageIterator.next(getEntityFeatureslistener); - } - - // We don't know when the pagination will not finish. To not - // block the following interval request to start, we return immediately. - // Pagination will stop itself when the time is up. 
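    // Since PageListener is deleted in this PR, a condensed sketch of its control flow
    // may help future readers; the types below are simplified stand-ins, not the real
    // CompositeRetriever classes:
    //
    //     interface Page { boolean isEmpty(); }
    //     interface PageIterator { boolean hasNext(); void next(PageCallback callback); }
    //     interface PageCallback { void onResponse(Page page); }
    //
    //     final class SelfRearmingCallback implements PageCallback {
    //         private final PageIterator pageIterator;
    //         SelfRearmingCallback(PageIterator pageIterator) { this.pageIterator = pageIterator; }
    //         @Override
    //         public void onResponse(Page page) {
    //             if (pageIterator.hasNext()) {
    //                 pageIterator.next(this);            // re-arm before processing
    //             }
    //             if (page != null && !page.isEmpty()) {
    //                 // fan entities in this page out to their model-hosting nodes
    //             }
    //         }
    //     }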
- if (previousException.isPresent()) { - listener.onFailure(previousException.get()); - } else { - listener - .onResponse( - new AnomalyResultResponse(new ArrayList(), null, null, anomalyDetector.getIntervalInMinutes(), true) - ); - } - return; - } - - // HC logic ends and single entity logic starts here - // We are going to use only 1 model partition for a single stream detector. - // That's why we use 0 here. - String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0); - Optional asRCFNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); - if (!asRCFNode.isPresent()) { - listener.onFailure(new InternalFailure(adID, "RCF model node is not available.")); - return; - } - - DiscoveryNode rcfNode = asRCFNode.get(); - - // we have already returned listener inside shouldStart method - if (!shouldStart(listener, adID, anomalyDetector, rcfNode.getId(), rcfModelID)) { - return; - } - - featureManager - .getCurrentFeatures( - anomalyDetector, - dataStartTime, - dataEndTime, - onFeatureResponseForSingleEntityDetector(adID, anomalyDetector, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime) - ); - } - - // For single entity detector - private ActionListener onFeatureResponseForSingleEntityDetector( - String adID, - AnomalyDetector detector, - ActionListener listener, - String rcfModelId, - DiscoveryNode rcfNode, - long dataStartTime, - long dataEndTime - ) { - return ActionListener.wrap(featureOptional -> { - List featureInResponse = null; - if (featureOptional.getUnprocessedFeatures().isPresent()) { - featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); - } - - if (!featureOptional.getProcessedFeatures().isPresent()) { - Optional exception = coldStartIfNoCheckPoint(detector); - if (exception.isPresent()) { - listener.onFailure(exception.get()); - return; - } - - if (!featureOptional.getUnprocessedFeatures().isPresent()) { - // Feature not available is common when we have data holes. Respond empty response - // and don't log to avoid bloating our logs. 
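    // The deleted callback below branches on which features came back; restated as a
    // tiny classifier (hypothetical booleans, only to make the decision order explicit):
    //
    //     static String featureState(boolean processed, boolean unprocessed) {
    //         if (processed) return "full shingle: send processed features to the RCF node";
    //         if (unprocessed) return "partial window: return raw features, no score";
    //         return "data hole: empty response; cold start if no checkpoint";
    //     }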
- LOG.debug("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID); - listener - .onResponse( - new AnomalyResultResponse( - new ArrayList(), - "No data in current detection window", - null, - null, - false - ) - ); - } else { - LOG.debug("Return at least current feature value between {} and {} for {}", dataStartTime, dataEndTime, adID); - listener - .onResponse( - new AnomalyResultResponse(featureInResponse, "No full shingle in current detection window", null, null, false) - ); - } - return; - } - - final AtomicReference failure = new AtomicReference(); - - LOG.info("Sending RCF request to {} for model {}", rcfNode.getId(), rcfModelId); - - RCFActionListener rcfListener = new RCFActionListener( - rcfModelId, - failure, - rcfNode.getId(), - detector, - listener, - featureInResponse, - adID - ); - - transportService - .sendRequest( - rcfNode, - RCFResultAction.NAME, - new RCFResultRequest(adID, rcfModelId, featureOptional.getProcessedFeatures().get()), - option, - new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) - ); - }, exception -> { handleQueryFailure(exception, listener, adID); }); - } - - private void handleQueryFailure(Exception exception, ActionListener listener, String adID) { - Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); - - if (convertedQueryFailureException instanceof EndRunException) { - // invalid feature query - listener.onFailure(convertedQueryFailureException); - } else { - handleExecuteException(convertedQueryFailureException, listener, adID); - } - } - - /** - * Convert a query related exception to EndRunException - * - * These query exception can happen during the starting phase of the OpenSearch - * process. Thus, set the stopNow parameter of these EndRunException to false - * and confirm the EndRunException is not a false positive. - * - * @param exception Exception - * @param adID detector Id - * @return the converted exception if the exception is query related - */ - private Exception convertedQueryFailureException(Exception exception, String adID) { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - return new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false).countedInStats(false); - } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { - // This is to catch invalid aggregation on wrong field type. For example, - // sum aggregation on text field. We should end detector run for such case. - return new EndRunException( - adID, - INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), - exception, - false - ).countedInStats(false); - } - - return exception; - } - - /** - * Verify failure of rcf or threshold models. If there is no model, trigger cold - * start. If there is an exception for the previous cold start of this detector, - * throw exception to the caller. - * - * @param failure object that may contain exceptions thrown - * @param detector detector object - * @return exception if AD job execution gets resource not found exception - * @throws Exception when the input failure is not a ResourceNotFoundException. - * List of exceptions we can throw - * 1. Exception from cold start: - * 1). InternalFailure due to - * a. OpenSearchTimeoutException thrown by putModelCheckpoint during cold start - * 2). EndRunException with endNow equal to false - * a. training data not available - * b. 
cold start cannot succeed - * c. invalid training data - * 3) EndRunException with endNow equal to true - * a. invalid search query - * 2. LimitExceededException from one of RCF model node when the total size of the models - * is more than X% of heap memory. - * 3. InternalFailure wrapping OpenSearchTimeoutException inside caused by - * RCF/Threshold model node failing to get checkpoint to restore model before timeout. - */ - private Exception coldStartIfNoModel(AtomicReference failure, AnomalyDetector detector) throws Exception { - Exception exp = failure.get(); - if (exp == null) { - return null; - } - - // return exceptions like LimitExceededException to caller - if (!(exp instanceof ResourceNotFoundException)) { - return exp; - } - - // fetch previous cold start exception - String adID = detector.getId(); - final Optional previousException = stateManager.fetchExceptionAndClear(adID); - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); - if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { - return exception; - } - } - LOG.info("Trigger cold start for {}", detector.getId()); - coldStart(detector); - return previousException.orElse(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); - } - - private void findException(Throwable cause, String adID, AtomicReference failure, String nodeId) { - if (cause == null) { - LOG.error(new ParameterizedMessage("Null input exception")); - return; - } - if (cause instanceof Error) { - // we cannot do anything with Error. - LOG.error(new ParameterizedMessage("Error during prediction for {}: ", adID), cause); - return; - } - - Exception causeException = (Exception) cause; - - if (causeException instanceof TimeSeriesException) { - failure.set(causeException); - } else if (causeException instanceof NotSerializableExceptionWrapper) { - // we only expect this happens on AD exceptions - Optional actualException = NotSerializedExceptionName - .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, adID); - if (actualException.isPresent()) { - TimeSeriesException adException = actualException.get(); - failure.set(adException); - if (adException instanceof ResourceNotFoundException) { - // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF - // 1.0 - // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node - // after consecutive failures. 
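    // The mute protocol assumed around addPressure (inferred from its usage in this
    // class, not spelled out by the PR): model-node failures increment a per
    // (node, detector) counter, the first success resets it, and isMuted()
    // short-circuits requests once the counter crosses its threshold:
    //
    //     stateManager.addPressure(nodeId, adID);              // on failure (below)
    //     stateManager.resetBackpressureCounter(nodeId, adID); // on success
    //     stateManager.isMuted(nodeId, adID)                   // consulted before fan-out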
- stateManager.addPressure(nodeId, adID); - } - } else { - // some unexpected bugs occur while predicting anomaly - failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); - } - } else if (causeException instanceof IndexNotFoundException - && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME)) { - // checkpoint index does not exist - // ResourceNotFoundException will trigger cold start later - failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); - } else if (causeException instanceof OpenSearchTimeoutException) { - // we can have OpenSearchTimeoutException when a node tries to load RCF or - // threshold model - failure.set(new InternalFailure(adID, causeException)); - } else if (causeException instanceof IllegalArgumentException) { - // we can have IllegalArgumentException when a model is corrupted - failure.set(new InternalFailure(adID, causeException)); - } else { - // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. - // NoShardAvailableActionException) while predicting anomaly - failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); - } - } - - void handleExecuteException(Exception ex, ActionListener listener, String adID) { - if (ex instanceof ClientException) { - listener.onFailure(ex); - } else if (ex instanceof TimeSeriesException) { - listener.onFailure(new InternalFailure((TimeSeriesException) ex)); - } else { - Throwable cause = ExceptionsHelper.unwrapCause(ex); - listener.onFailure(new InternalFailure(adID, cause)); - } - } - - private boolean invalidQuery(SearchPhaseExecutionException ex) { - // If all shards return bad request and failure cause is IllegalArgumentException, we - // consider the feature query is invalid and will not count the error in failure stats. 
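    // The all-shards-agree rule below can be read as a predicate; a sketch with plain
    // types standing in for ShardSearchFailure (illustrative only):
    //
    //     static boolean allShardsRejectQuery(List<Map.Entry<Integer, Throwable>> failures) {
    //         for (Map.Entry<Integer, Throwable> f : failures) {
    //             if (f.getKey() != 400 || !(f.getValue() instanceof IllegalArgumentException)) {
    //                 return false; // one disagreeing shard keeps the error transient
    //             }
    //         }
    //         return true; // unanimous 400 + IllegalArgumentException: invalid feature query
    //     }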
- for (ShardSearchFailure failure : ex.shardFailures()) { - if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { - return false; - } - } - return true; - } - - // For single entity detector - class RCFActionListener implements ActionListener { - private String modelID; - private AtomicReference failure; - private String rcfNodeID; - private AnomalyDetector detector; - private ActionListener listener; - private List featureInResponse; - private final String adID; - - RCFActionListener( - String modelID, - AtomicReference failure, - String rcfNodeID, - AnomalyDetector detector, - ActionListener listener, - List features, - String adID - ) { - this.modelID = modelID; - this.failure = failure; - this.rcfNodeID = rcfNodeID; - this.detector = detector; - this.listener = listener; - this.featureInResponse = features; - this.adID = adID; - } - - @Override - public void onResponse(RCFResultResponse response) { - try { - stateManager.resetBackpressureCounter(rcfNodeID, adID); - if (response != null) { - listener - .onResponse( - new AnomalyResultResponse( - response.getAnomalyGrade(), - response.getConfidence(), - response.getRCFScore(), - featureInResponse, - null, - response.getTotalUpdates(), - detector.getIntervalInMinutes(), - false, - response.getRelativeIndex(), - response.getAttribution(), - response.getPastValues(), - response.getExpectedValuesList(), - response.getLikelihoodOfValues(), - response.getThreshold() - ) - ); - } else { - LOG.warn(NULL_RESPONSE + " {} for {}", modelID, rcfNodeID); - listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); - } - } catch (Exception ex) { - LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); - handleExecuteException(ex, listener, adID); - } - } - - @Override - public void onFailure(Exception e) { - try { - handlePredictionFailure(e, adID, rcfNodeID, failure); - Exception exception = coldStartIfNoModel(failure, detector); - if (exception != null) { - listener.onFailure(exception); - } else { - listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception")); - } - } catch (Exception ex) { - LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); - handleExecuteException(ex, listener, adID); - } - } - } - - /** - * Handle a prediction failure. Possibly (i.e., we don't always need to do that) - * convert the exception to a form that AD can recognize and handle and sets the - * input failure reference to the converted exception. - * - * @param e prediction exception - * @param adID Detector Id - * @param nodeID Node Id - * @param failure Parameter to receive the possibly converted function for the - * caller to deal with - */ - private void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { - LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); - if (e == null) { - return; - } - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (hasConnectionIssue(cause)) { - handleConnectionException(nodeID, adID); - } else { - findException(cause, adID, failure, nodeID); - } - } - - /** - * Check if the input exception indicates connection issues. - * During blue-green deployment, we may see ActionNotFoundTransportException. - * Count that as connection issue and isolate that node if it continues to happen. 
- * - * @param e exception - * @return true if we get disconnected from the node or the node is not in the - * right state (being closed) or transport request times out (sent from TimeoutHandler.run) - */ - private boolean hasConnectionIssue(Throwable e) { - return e instanceof ConnectTransportException - || e instanceof NodeClosedException - || e instanceof ReceiveTimeoutTransportException - || e instanceof NodeNotConnectedException - || e instanceof ConnectException - || NetworkExceptionHelper.isCloseConnectionException(e) - || e instanceof ActionNotFoundTransportException; - } - - private void handleConnectionException(String node, String detectorId) { - final DiscoveryNodes nodes = clusterService.state().nodes(); - if (!nodes.nodeExists(node)) { - hashRing.buildCirclesForRealtimeAD(); - return; - } - // rebuilding is not done or node is unresponsive - stateManager.addPressure(node, detectorId); - } - - /** - * Since we need to read from customer index and write to anomaly result index, - * we need to make sure we can read and write. - * - * @param state Cluster state - * @return whether we have global block or not - */ - private boolean checkGlobalBlock(ClusterState state) { - return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null - || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null; - } - - /** - * Similar to checkGlobalBlock, we check block on the indices level. - * - * @param state Cluster state - * @param level block level - * @param indices the indices on which to check block - * @return whether any of the index has block on the level. - */ - private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) { - // the original index might be an index expression with wildcards like "log*", - // so we need to expand the expression to concrete index name - String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices); - - return state.blocks().indicesBlockedException(level, concreteIndices) != null; - } - - /** - * Check if we should start anomaly prediction. - * - * @param listener listener to respond back to AnomalyResultRequest. - * @param adID detector ID - * @param detector detector instance corresponds to adID - * @param rcfNodeId the rcf model hosting node ID for adID - * @param rcfModelID the rcf model ID for adID - * @return if we can start anomaly prediction. 
- */ - private boolean shouldStart( - ActionListener listener, - String adID, - AnomalyDetector detector, - String rcfNodeId, - String rcfModelID - ) { - ClusterState state = clusterService.state(); - if (checkGlobalBlock(state)) { - listener.onFailure(new InternalFailure(adID, READ_WRITE_BLOCKED)); - return false; - } - - if (stateManager.isMuted(rcfNodeId, adID)) { - listener - .onFailure( - new InternalFailure( - adID, - String.format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID) - ) - ); - return false; - } - - if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) { - listener.onFailure(new InternalFailure(adID, INDEX_READ_BLOCKED)); - return false; - } - - return true; - } - - private void coldStart(AnomalyDetector detector) { - String detectorId = detector.getId(); - - // If last cold start is not finished, we don't trigger another one - if (stateManager.isColdStartRunning(detectorId)) { - return; - } - - final Releasable coldStartFinishingCallback = stateManager.markColdStartRunning(detectorId); - - ActionListener> listener = ActionListener.wrap(trainingData -> { - if (trainingData.isPresent()) { - double[][] dataPoints = trainingData.get(); - - ActionListener trainModelListener = ActionListener - .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> { - if (exception instanceof TimeSeriesException) { - // e.g., partitioned model exceeds memory limit - stateManager.setException(detectorId, exception); - } else if (exception instanceof IllegalArgumentException) { - // IllegalArgumentException due to invalid training data - stateManager - .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false)); - } else if (exception instanceof OpenSearchTimeoutException) { - stateManager - .setException( - detectorId, - new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) - ); - } else { - stateManager - .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false)); - } - }); - - modelManager - .trainModel( - detector, - dataPoints, - new ThreadedActionListener<>( - LOG, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - trainModelListener, - false - ) - ); - } else { - stateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); - } - }, exception -> { - if (exception instanceof OpenSearchTimeoutException) { - stateManager.setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception)); - } else if (exception instanceof TimeSeriesException) { - // e.g., Invalid search query - stateManager.setException(detectorId, exception); - } else { - stateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); - } - }); - - final ActionListener> listenerWithReleaseCallback = ActionListener - .runAfter(listener, coldStartFinishingCallback::close); - - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> featureManager - .getColdStartData( - detector, - new ThreadedActionListener<>( - LOG, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - listenerWithReleaseCallback, - false - ) - ) - ); - } - - /** - * Check if checkpoint for an detector exists or not. If not and previous - * run is not EndRunException whose endNow is true, trigger cold start. 
- * @param detector detector object - * @return previous cold start exception - */ - private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { - String detectorId = detector.getId(); - - Optional previousException = stateManager.fetchExceptionAndClear(detectorId); - - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error(new ParameterizedMessage("Previous exception of {}:", detectorId), exception); - if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { - return previousException; - } - } - - stateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> { - if (!checkpointExists) { - LOG.info("Trigger cold start for {}", detectorId); - coldStart(detector); - } - }, exception -> { - Throwable cause = ExceptionsHelper.unwrapCause(exception); - if (cause instanceof IndexNotFoundException) { - LOG.info("Trigger cold start for {}", detectorId); - coldStart(detector); - } else { - String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId); - LOG.error(errorMsg, exception); - stateManager.setException(detectorId, new TimeSeriesException(errorMsg, exception)); - } - })); - - return previousException; - } - - class EntityResultListener implements ActionListener { - private String nodeId; - private final String adID; - private AtomicReference failure; - - EntityResultListener(String nodeId, String adID, AtomicReference failure) { - this.nodeId = nodeId; - this.adID = adID; - this.failure = failure; - } - - @Override - public void onResponse(AcknowledgedResponse response) { - try { - if (response.isAcknowledged() == false) { - LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); - stateManager.addPressure(nodeId, adID); - } else { - stateManager.resetBackpressureCounter(nodeId, adID); - } - } catch (Exception ex) { - LOG.error("Unexpected exception: {} for {}", ex, adID); - handleException(ex); - } - } - - @Override - public void onFailure(Exception e) { - try { - // e.g., we have connection issues with all of the nodes while restarting clusters - LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); - - handleException(e); - - } catch (Exception ex) { - LOG.error("Unexpected exception: {} for {}", ex, adID); - handleException(ex); - } - } - - private void handleException(Exception e) { - handlePredictionFailure(e, adID, nodeId, failure); - if (failure.get() != null) { - stateManager.setException(adID, failure.get()); - } - } - } } diff --git a/src/main/java/org/opensearch/ad/transport/CronAction.java b/src/main/java/org/opensearch/ad/transport/CronAction.java index 1e64a0f45..91a5aa2cb 100644 --- a/src/main/java/org/opensearch/ad/transport/CronAction.java +++ b/src/main/java/org/opensearch/ad/transport/CronAction.java @@ -12,11 +12,12 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.CronResponse; public class CronAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "cron"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "cron"; public static final CronAction INSTANCE = new CronAction(); private CronAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java b/src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java similarity index 55% rename from src/main/java/org/opensearch/ad/transport/DeleteModelAction.java rename to src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java index 3af6982b0..c4eeef176 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.DeleteModelResponse; -public class DeleteModelAction extends ActionType { +public class DeleteADModelAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; - public static final DeleteModelAction INSTANCE = new DeleteModelAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; + public static final DeleteADModelAction INSTANCE = new DeleteADModelAction(); - private DeleteModelAction() { + private DeleteADModelAction() { super(NAME, DeleteModelResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java new file mode 100644 index 000000000..ec44ae9f9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.BaseDeleteModelTransportAction; +import org.opensearch.timeseries.transport.DeleteModelNodeRequest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class DeleteADModelTransportAction extends + BaseDeleteModelTransportAction { + private static final Logger LOG = LogManager.getLogger(DeleteADModelTransportAction.class); + private ADModelManager modelManager; + private FeatureManager featureManager; + + @Inject + public DeleteADModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cache, + ADTaskCacheManager adTaskCacheManager, + ADEntityColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + cache, + adTaskCacheManager, + coldStarter, + DeleteADModelAction.NAME + ); + this.modelManager = modelManager; + this.featureManager = featureManager; + } + + /** + * Delete the checkpoint document (including both the RCF and thresholding models), in-memory models, + * buffered shingle data, transport state, and anomaly results. + * + * @param request delete request + * @return delete response including the local node ID.
+ */ + @Override + protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { + super.nodeOperation(request); + String adID = request.getConfigID(); + + // delete in-memory models and model checkpoint + modelManager + .clear( + adID, + ActionListener + .wrap( + r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), + e -> LOG.error("Fail to delete model for " + adID, e) + ) + ); + + // delete buffered shingle data + featureManager.clear(adID); + + LOG.info("Finished deleting ad models for {}", adID); + return new DeleteModelNodeResponse(clusterService.localNode()); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java index 75dc34638..70d655507 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class DeleteAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; public static final DeleteAnomalyDetectorAction INSTANCE = new DeleteAnomalyDetectorAction(); private DeleteAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java index 221a935bc..33124125d 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java @@ -11,58 +11,28 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_DETECTOR; -import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import java.io.IOException; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import 
org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.Job; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.transport.BaseDeleteConfigTransportAction; import org.opensearch.transport.TransportService; -public class DeleteAnomalyDetectorTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(DeleteAnomalyDetectorTransportAction.class); - private final Client client; - private final ClusterService clusterService; - private final TransportService transportService; - private NamedXContentRegistry xContentRegistry; - private final ADTaskManager adTaskManager; - private volatile Boolean filterByEnabled; +public class DeleteAnomalyDetectorTransportAction extends + BaseDeleteConfigTransportAction { @Inject public DeleteAnomalyDetectorTransportAction( @@ -72,153 +42,24 @@ public DeleteAnomalyDetectorTransportAction( ClusterService clusterService, Settings settings, NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, ADTaskManager adTaskManager ) { - super(DeleteAnomalyDetectorAction.NAME, transportService, actionFilters, DeleteAnomalyDetectorRequest::new); - this.transportService = transportService; - this.client = client; - this.clusterService = clusterService; - this.xContentRegistry = xContentRegistry; - this.adTaskManager = adTaskManager; - filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - } - - @Override - protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, ActionListener actionListener) { - String detectorId = request.getDetectorID(); - LOG.info("Delete anomaly detector job {}", detectorId); - User user = getUserContext(client); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_DETECTOR); - // By the time request reaches here, the user permissions are validated by Security plugin. - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute( - user, - detectorId, - filterByEnabled, - listener, - (anomalyDetector) -> adTaskManager.getDetector(detectorId, detector -> { - if (!detector.isPresent()) { - // In a mixed cluster, if delete detector request routes to node running AD1.0, then it will - // not delete detector tasks. User can re-delete these deleted detector after cluster upgraded, - // in that case, the detector is not present. 
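-                    // the detector config is already gone, so just clean up any leftover tasks and the job doc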
- LOG.info("Can't find anomaly detector {}", detectorId); - adTaskManager.deleteADTasks(detectorId, () -> deleteAnomalyDetectorJobDoc(detectorId, listener), listener); - return; - } - // Check if there is realtime job or historical analysis task running. If none of these running, we - // can delete the detector. - getDetectorJob(detectorId, listener, () -> { - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); - } else { - adTaskManager.deleteADTasks(detectorId, () -> deleteAnomalyDetectorJobDoc(detectorId, listener), listener); - } - }, transportService, true, listener); - }); - }, listener), - client, - clusterService, - xContentRegistry, - AnomalyDetector.class - ); - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } - } - - private void deleteAnomalyDetectorJobDoc(String detectorId, ActionListener listener) { - LOG.info("Delete anomaly detector job {}", detectorId); - DeleteRequest deleteRequest = new DeleteRequest(CommonName.JOB_INDEX, detectorId) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, ActionListener.wrap(response -> { - if (response.getResult() == DocWriteResponse.Result.DELETED || response.getResult() == DocWriteResponse.Result.NOT_FOUND) { - deleteDetectorStateDoc(detectorId, listener); - } else { - String message = "Fail to delete anomaly detector job " + detectorId; - LOG.error(message); - listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); - } - }, exception -> { - LOG.error("Failed to delete AD job for " + detectorId, exception); - if (exception instanceof IndexNotFoundException) { - deleteDetectorStateDoc(detectorId, listener); - } else { - LOG.error("Failed to delete anomaly detector job", exception); - listener.onFailure(exception); - } - })); - } - - private void deleteDetectorStateDoc(String detectorId, ActionListener listener) { - LOG.info("Delete detector info {}", detectorId); - DeleteRequest deleteRequest = new DeleteRequest(ADCommonName.DETECTION_STATE_INDEX, detectorId); - client.delete(deleteRequest, ActionListener.wrap(response -> { - // whether deleted state doc or not, continue as state doc may not exist - deleteAnomalyDetectorDoc(detectorId, listener); - }, exception -> { - if (exception instanceof IndexNotFoundException) { - deleteAnomalyDetectorDoc(detectorId, listener); - } else { - LOG.error("Failed to delete detector state", exception); - listener.onFailure(exception); - } - })); - } - - private void deleteAnomalyDetectorDoc(String detectorId, ActionListener listener) { - LOG.info("Delete anomaly detector {}", detectorId); - DeleteRequest deleteRequest = new DeleteRequest(CommonName.CONFIG_INDEX, detectorId) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - listener.onResponse(deleteResponse); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - } - - private void getDetectorJob(String detectorId, ActionListener listener, ExecutorFunction function) { - if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { - GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); - client.get(request, 
ActionListener.wrap(response -> onGetAdJobResponseForWrite(response, listener, function), exception -> { - LOG.error("Fail to get anomaly detector job: " + detectorId, exception); - listener.onFailure(exception); - })); - } else { - function.execute(); - } - } - - private void onGetAdJobResponseForWrite(GetResponse response, ActionListener listener, ExecutorFunction function) - throws IOException { - if (response.isExists()) { - String adJobId = response.getId(); - if (adJobId != null) { - // check if AD job is running on the detector, if yes, we can't delete the detector - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Job adJob = Job.parse(parser); - if (adJob.isEnabled()) { - listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); - return; - } - } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + adJobId; - LOG.error(message, e); - } - } - } - function.execute(); + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + nodeStateManager, + adTaskManager, + DeleteAnomalyDetectorAction.NAME, + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, + AnalysisType.AD, + ADCommonName.DETECTION_STATE_INDEX, + AnomalyDetector.class, + ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java index ae9de4c95..84065dbb7 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java @@ -12,12 +12,12 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.index.reindex.BulkByScrollResponse; public class DeleteAnomalyResultsAction extends ActionType { // External Action which used for public facing RestAPIs. 
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; public static final DeleteAnomalyResultsAction INSTANCE = new DeleteAnomalyResultsAction(); private DeleteAnomalyResultsAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java index e2db9ed4a..8c218ac64 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java @@ -13,8 +13,6 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_AD_RESULT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; import org.apache.logging.log4j.LogManager; @@ -32,6 +30,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.transport.TransportService; public class DeleteAnomalyResultsTransportAction extends HandledTransportAction { @@ -61,7 +60,7 @@ protected void doExecute(Task task, DeleteByQueryRequest request, ActionListener } public void delete(DeleteByQueryRequest request, ActionListener listener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { validateRole(request, user, listener); } catch (Exception e) { @@ -79,7 +78,7 @@ private void validateRole(DeleteByQueryRequest request, User user, ActionListene } else { // Security is enabled and backend role filter is enabled try { - addUserBackendRolesFilter(user, request.getSearchRequest().source()); + ParseUtils.addUserBackendRolesFilter(user, request.getSearchRequest().source()); client.execute(DeleteByQueryAction.INSTANCE, request, listener); } catch (Exception e) { listener.onFailure(e); diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultAction.java similarity index 62% rename from src/main/java/org/opensearch/ad/transport/EntityResultAction.java rename to src/main/java/org/opensearch/ad/transport/EntityADResultAction.java index c519858b4..f17c23416 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultAction.java @@ -13,14 +13,14 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; -public class EntityResultAction extends ActionType { +public class EntityADResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; - public static final EntityResultAction INSTANCE = new EntityResultAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; + public static final EntityADResultAction INSTANCE = new EntityADResultAction(); - private EntityResultAction() { + private EntityADResultAction() { super(NAME, AcknowledgedResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java new file mode 100644 index 000000000..a6de9effa --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java @@ -0,0 +1,176 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.EntityResultProcessor; +import org.opensearch.timeseries.transport.EntityResultRequest; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Entry point for the HCAD workflow. We have created multiple queues for + * coordinating the workflow. The overall workflow is: 1. We store as many + * frequently used entity models in a cache as allowed by the memory limit (10% + * heap).
If an entity feature is a hit, we use the in-memory model to detect + * anomalies and record results using the result write queue. 2. If an entity + * feature is a miss, we check whether there is free memory or whether another + * entity's model can be evicted. An in-memory entity's frequency may be lower + * than the cache-miss entity's. If that's the case, we replace the lower-frequency + * entity's model with the higher-frequency entity's model. To load the higher- + * frequency entity's model, we first check if a model exists on disk by sending + * a checkpoint read queue request. If there is a checkpoint, we load it to + * memory, perform detection, and save the result using the result write queue. + * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the + * checkpoint write queue. 3. We also have the cold entity queue configured for + * cold entities, and the model training and inference are connected by serial + * juxtaposition to limit resource usage. + */ +public class EntityADResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityADResultTransportAction.class); + private CircuitBreakerService adCircuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ThreadPool threadPool; + private EntityResultProcessor intervalDataProcessor; + + private final ADCacheProvider entityCache; + private final ADModelManager manager; + private final ADStats timeSeriesStats; + private final ADColdStartWorker entityColdStartWorker; + private final ADCheckpointReadWorker checkpointReadQueue; + private final ADColdEntityWorker coldEntityQueue; + private final ADSaveResultStrategy adSaveResultStategy; + + @Inject + public EntityADResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ADModelManager manager, + CircuitBreakerService adCircuitBreakerService, + ADCacheProvider entityCache, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + ADResultWriteWorker resultWriteQueue, + ADCheckpointReadWorker checkpointReadQueue, + ADColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + ADColdStartWorker entityColdStartWorker, + ADStats timeSeriesStats, + ADSaveResultStrategy adSaveResultStategy + ) { + super(EntityADResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.adCircuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.threadPool = threadPool; + + this.entityCache = entityCache; + this.manager = manager; + this.timeSeriesStats = timeSeriesStats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.adSaveResultStategy = adSaveResultStategy; + this.intervalDataProcessor = null; + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (adCircuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String detectorId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); + + if
(previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", detectorId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, detectorId); + } + + this.intervalDataProcessor = new EntityResultProcessor<>( + entityCache, + manager, + timeSeriesStats, + entityColdStartWorker, + checkpointReadQueue, + coldEntityQueue, + adSaveResultStategy, + StatNames.AD_MODEL_CORRUTPION_COUNT + ); + + stateManager + .getConfig( + detectorId, + request.getAnalysisType(), + intervalDataProcessor.onGetConfig(listener, detectorId, request, previousException, request.getAnalysisType()) + ); + } catch (Exception exception) { + LOG.error("fail to get entity's anomaly grade", exception); + listener.onFailure(exception); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java deleted file mode 100644 index d17ce7137..000000000 --- a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java +++ /dev/null @@ -1,356 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport; - -import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndex; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.EntityFeatureRequest; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.ad.ratelimit.ResultWriteWorker; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.AnalysisType; -import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.breaker.CircuitBreakerService; -import 
org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.LimitExceededException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.model.Config; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; -import org.opensearch.transport.TransportService; - -/** - * Entry-point for HCAD workflow. We have created multiple queues for coordinating - * the workflow. The overrall workflow is: - * 1. We store as many frequently used entity models in a cache as allowed by the - * memory limit (10% heap). If an entity feature is a hit, we use the in-memory model - * to detect anomalies and record results using the result write queue. - * 2. If an entity feature is a miss, we check if there is free memory or any other - * entity's model can be evacuated. An in-memory entity's frequency may be lower - * compared to the cache miss entity. If that's the case, we replace the lower - * frequency entity's model with the higher frequency entity's model. To load the - * higher frequency entity's model, we first check if a model exists on disk by - * sending a checkpoint read queue request. If there is a checkpoint, we load it - * to memory, perform detection, and save the result using the result write queue. - * Otherwise, we enqueue a cold start request to the cold start queue for model - * training. If training is successful, we save the learned model via the checkpoint - * write queue. - * 3. We also have the cold entity queue configured for cold entities, and the model - * training and inference are connected by serial juxtaposition to limit resource usage. 
- */ -public class EntityResultTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class); - private ModelManager modelManager; - private CircuitBreakerService adCircuitBreakerService; - private CacheProvider cache; - private final NodeStateManager stateManager; - private ADIndexManagement indexUtil; - private ResultWriteWorker resultWriteQueue; - private CheckpointReadWorker checkpointReadQueue; - private ColdEntityWorker coldEntityQueue; - private ThreadPool threadPool; - private EntityColdStartWorker entityColdStartWorker; - private ADStats adStats; - - @Inject - public EntityResultTransportAction( - ActionFilters actionFilters, - TransportService transportService, - ModelManager manager, - CircuitBreakerService adCircuitBreakerService, - CacheProvider entityCache, - NodeStateManager stateManager, - ADIndexManagement indexUtil, - ResultWriteWorker resultWriteQueue, - CheckpointReadWorker checkpointReadQueue, - ColdEntityWorker coldEntityQueue, - ThreadPool threadPool, - EntityColdStartWorker entityColdStartWorker, - ADStats adStats - ) { - super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); - this.modelManager = manager; - this.adCircuitBreakerService = adCircuitBreakerService; - this.cache = entityCache; - this.stateManager = stateManager; - this.indexUtil = indexUtil; - this.resultWriteQueue = resultWriteQueue; - this.checkpointReadQueue = checkpointReadQueue; - this.coldEntityQueue = coldEntityQueue; - this.threadPool = threadPool; - this.entityColdStartWorker = entityColdStartWorker; - this.adStats = adStats; - } - - @Override - protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { - if (adCircuitBreakerService.isOpen()) { - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); - listener.onFailure(new LimitExceededException(request.getId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); - return; - } - - try { - String detectorId = request.getId(); - - Optional previousException = stateManager.fetchExceptionAndClear(detectorId); - - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error("Previous exception of {}: {}", detectorId, exception); - if (exception instanceof EndRunException) { - EndRunException endRunException = (EndRunException) exception; - if (endRunException.isEndNow()) { - listener.onFailure(exception); - return; - } - } - - listener = ExceptionUtil.wrapListener(listener, exception, detectorId); - } - - stateManager.getConfig(detectorId, AnalysisType.AD, onGetDetector(listener, detectorId, request, previousException)); - } catch (Exception exception) { - LOG.error("fail to get entity's anomaly grade", exception); - listener.onFailure(exception); - } - } - - private ActionListener> onGetDetector( - ActionListener listener, - String detectorId, - EntityResultRequest request, - Optional prevException - ) { - return ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); - return; - } - - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - - if (request.getEntities() == null) { - listener.onFailure(new EndRunException(detectorId, "Fail to get any entities from request.", false)); - return; - } - - Instant executionStartTime = Instant.now(); 
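-            // score entities whose models are already cached in memory; collect the rest as cache misses for the queues below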
- Map cacheMissEntities = new HashMap<>(); - for (Entry entityEntry : request.getEntities().entrySet()) { - Entity categoricalValues = entityEntry.getKey(); - - if (isEntityFromOldNodeMsg(categoricalValues) - && detector.getCategoryFields() != null - && detector.getCategoryFields().size() == 1) { - Map attrValues = categoricalValues.getAttributes(); - // handle a request from a version before OpenSearch 1.1. - categoricalValues = Entity - .createSingleAttributeEntity(detector.getCategoryFields().get(0), attrValues.get(ADCommonName.EMPTY_FIELD)); - } - - Optional modelIdOptional = categoricalValues.getModelId(detectorId); - if (false == modelIdOptional.isPresent()) { - continue; - } - - String modelId = modelIdOptional.get(); - double[] datapoint = entityEntry.getValue(); - ModelState entityModel = cache.get().get(modelId, detector); - if (entityModel == null) { - // cache miss - cacheMissEntities.put(categoricalValues, datapoint); - continue; - } - try { - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(datapoint, entityModel, modelId, categoricalValues, detector.getShingleSize()); - // result.getRcfScore() = 0 means the model is not initialized - // result.getGrade() = 0 means it is not an anomaly - // So many OpenSearchRejectedExecutionException if we write no matter what - if (result.getRcfScore() > 0) { - List resultsToSave = result - .toIndexableResults( - detector, - Instant.ofEpochMilli(request.getStart()), - Instant.ofEpochMilli(request.getEnd()), - executionStartTime, - Instant.now(), - ParseUtils.getFeatureData(datapoint, detector), - Optional.ofNullable(categoricalValues), - indexUtil.getSchemaVersion(ADIndex.RESULT), - modelId, - null, - null - ); - for (AnomalyResult r : resultsToSave) { - resultWriteQueue - .put( - new ResultWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, - r, - detector.getCustomResultIndex() - ) - ); - } - } - } catch (IllegalArgumentException e) { - // fail to score likely due to model corruption. Re-cold start to recover. 
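-                    // drop the corrupted model from the cache and enqueue a cold start request to retrain it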
- LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); - adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); - cache.get().removeEntityModel(detectorId, modelId); - entityColdStartWorker - .put( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - RequestPriority.MEDIUM, - categoricalValues, - datapoint, - request.getStart() - ) - ); - } - } - - // split hot and cold entities - Pair, List> hotColdEntities = cache - .get() - .selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector); - - List hotEntityRequests = new ArrayList<>(); - List coldEntityRequests = new ArrayList<>(); - - for (Entity hotEntity : hotColdEntities.getLeft()) { - double[] hotEntityValue = cacheMissEntities.get(hotEntity); - if (hotEntityValue == null) { - LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); - continue; - } - hotEntityRequests - .add( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - // hot entities has MEDIUM priority - RequestPriority.MEDIUM, - hotEntity, - hotEntityValue, - request.getStart() - ) - ); - } - - for (Entity coldEntity : hotColdEntities.getRight()) { - double[] coldEntityValue = cacheMissEntities.get(coldEntity); - if (coldEntityValue == null) { - LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); - continue; - } - coldEntityRequests - .add( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - // cold entities has LOW priority - RequestPriority.LOW, - coldEntity, - coldEntityValue, - request.getStart() - ) - ); - } - - checkpointReadQueue.putAll(hotEntityRequests); - coldEntityQueue.putAll(coldEntityRequests); - - // respond back - if (prevException.isPresent()) { - listener.onFailure(prevException.get()); - } else { - listener.onResponse(new AcknowledgedResponse(true)); - } - }, exception -> { - LOG - .error( - new ParameterizedMessage( - "fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", - detectorId, - request.getStart(), - request.getEnd() - ), - exception - ); - listener.onFailure(exception); - }); - } - - /** - * Whether the received entity comes from an node that doesn't support multi-category fields. - * This can happen during rolling-upgrade or blue/green deployment. - * - * Specifically, when receiving an EntityResultRequest from an incompatible node, - * EntityResultRequest(StreamInput in) gets an String that represents an entity. - * But Entity class requires both an category field name and value. Since we - * don't have access to detector config in EntityResultRequest(StreamInput in), - * we put CommonName.EMPTY_FIELD as the placeholder. In this method, - * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity - * comes from an incompatible node. If it is, we will add the field name back - * as EntityResultTranportAction has access to the detector config object. - * - * @param categoricalValues deserialized Entity from inbound message. - * @return Whether the received entity comes from an node that doesn't support multi-category fields. 
- */ - private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { - Map attrValues = categoricalValues.getAttributes(); - return (attrValues != null && attrValues.containsKey(ADCommonName.EMPTY_FIELD)); - } -} diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java index 309714cc8..43c62eed3 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java @@ -14,14 +14,15 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.JobResponse; -public class ForwardADTaskAction extends ActionType { +public class ForwardADTaskAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; public static final ForwardADTaskAction INSTANCE = new ForwardADTaskAction(); private ForwardADTaskAction() { - super(NAME, AnomalyDetectorJobResponse::new); + super(NAME, JobResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java index cab9a4f65..a0591e052 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java @@ -11,10 +11,6 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.model.ADTask.ERROR_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; - import java.util.Arrays; import java.util.List; @@ -23,10 +19,10 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.common.inject.Inject; @@ -35,17 +31,21 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; -public class ForwardADTaskTransportAction extends HandledTransportAction { +public class ForwardADTaskTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(ForwardADTaskTransportAction.class); private final TransportService transportService; private final ADTaskManager adTaskManager; private final ADTaskCacheManager adTaskCacheManager; + private 
final ADIndexJobActionHandler indexJobHander; // ========================================================= // Fields below contains cache for realtime AD on coordinating @@ -64,7 +64,8 @@ public ForwardADTaskTransportAction( ADTaskManager adTaskManager, ADTaskCacheManager adTaskCacheManager, FeatureManager featureManager, - NodeStateManager stateManager + NodeStateManager stateManager, + ADIndexJobActionHandler indexJobHander ) { super(ForwardADTaskAction.NAME, transportService, actionFilters, ForwardADTaskRequest::new); this.adTaskManager = adTaskManager; @@ -72,10 +73,11 @@ public ForwardADTaskTransportAction( this.adTaskCacheManager = adTaskCacheManager; this.featureManager = featureManager; this.stateManager = stateManager; + this.indexJobHander = indexJobHander; } @Override - protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener listener) { + protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener listener) { ADTaskAction adTaskAction = request.getAdTaskAction(); AnomalyDetector detector = request.getDetector(); DateRange detectionDateRange = request.getDetectionDateRange(); @@ -107,7 +109,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener case START: // Start historical analysis for detector logger.debug("Received START action for detector {}", detectorId); - adTaskManager.startDetector(detector, detectionDateRange, user, transportService, ActionListener.wrap(r -> { + indexJobHander.startConfig(detector, detectionDateRange, user, transportService, ActionListener.wrap(r -> { adTaskCacheManager.setDetectorTaskSlots(detector.getId(), availableTaskSlots); listener.onResponse(r); }, e -> listener.onFailure(e))); @@ -120,8 +122,10 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener if (!adTaskCacheManager.hasEntity(detectorId)) { adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); logger.info("Historical HC detector done, will remove from cache, detector id:{}", detectorId); - listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); - TaskState state = !adTask.isEntityTask() && adTask.getError() != null ? TaskState.FAILED : TaskState.FINISHED; + listener.onResponse(new JobResponse(detectorId)); + TaskState state = !adTask.isHistoricalEntityTask() && adTask.getError() != null + ? TaskState.FAILED + : TaskState.FINISHED; adTaskManager.setHCDetectorTaskDone(adTask, state, listener); } else { logger.debug("Run next entity for detector " + detectorId); @@ -132,11 +136,11 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTask.getParentTaskId(), ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, TaskState.RUNNING.name(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, adTaskManager.hcDetectorProgress(detectorId), - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, adTask.getError() != null ? adTask.getError() : "" ) ); @@ -154,7 +158,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener case PUSH_BACK_ENTITY: logger.debug("Received PUSH_BACK_ENTITY action for detector {}, task {}", detectorId, adTask.getTaskId()); // Push back entity to pending entities queue and run next entity. - if (adTask.isEntityTask()) { // AD task must be entity level task. + if (adTask.isHistoricalEntityTask()) { // AD task must be entity level task. 
adTaskCacheManager.removeRunningEntity(detectorId, entityValue); if (adTaskManager.isRetryableError(adTask.getError()) && !adTaskCacheManager.exceedRetryLimit(adTask.getConfigId(), adTask.getTaskId())) { @@ -176,7 +180,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener logger.debug("After scale down, only 1 task slot reserved for detector {}, run next entity", detectorId); adTaskManager.runNextEntityForHCADHistorical(adTask, transportService, listener); } - listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.ACCEPTED)); + listener.onResponse(new JobResponse(adTask.getTaskId())); } } else { logger.warn("Can only push back entity task"); @@ -193,7 +197,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTaskCacheManager.scaleUpDetectorTaskSlots(detectorId, newSlots); } } - listener.onResponse(new AnomalyDetectorJobResponse(detector.getId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(detector.getId())); break; case CANCEL: logger.debug("Received CANCEL action for detector {}", detectorId); @@ -203,10 +207,10 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener if (detector.isHighCardinality()) { adTaskCacheManager.clearPendingEntities(detectorId); adTaskCacheManager.removeRunningEntity(detectorId, entityValue); - if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isEntityTask()) { + if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isHistoricalEntityTask()) { adTaskManager.setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } - listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(adTask.getTaskId())); } else { listener.onFailure(new IllegalArgumentException("Only support cancel HC now")); } @@ -227,7 +231,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener for (String entity : staleRunningEntities) { adTaskManager.removeStaleRunningEntity(adTask, entity, transportService, listener); } - listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(adTask.getTaskId())); break; case CLEAN_CACHE: boolean historicalTask = adTask.isHistoricalTask(); @@ -249,7 +253,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener stateManager.clear(detectorId); featureManager.clear(detectorId); } - listener.onResponse(new AnomalyDetectorJobResponse(detector.getId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(detector.getId())); break; default: listener.onFailure(new OpenSearchStatusException("Unsupported AD task action " + adTaskAction, RestStatus.BAD_REQUEST)); diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java index c4232047d..c740ed24e 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class GetAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
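Stepping back from the ForwardADTaskTransportAction hunks above: every AnomalyDetectorJobResponse(id, 0, 0, 0, RestStatus.X) call site collapses into JobResponse(id), because the version/seqNo/primaryTerm and status arguments were always placeholders for these internal acknowledgements. A minimal sketch of what such an id-only response has to implement to ride the transport layer (JSON field name assumed for illustration):

    import java.io.IOException;
    import org.opensearch.core.action.ActionResponse;
    import org.opensearch.core.common.io.stream.StreamInput;
    import org.opensearch.core.common.io.stream.StreamOutput;
    import org.opensearch.core.xcontent.ToXContentObject;
    import org.opensearch.core.xcontent.XContentBuilder;

    public class JobResponse extends ActionResponse implements ToXContentObject {
        private final String id;

        public JobResponse(String id) {
            this.id = id;
        }

        // wire constructor: the ActionType constructor registers this reader
        // (super(NAME, JobResponse::new)) so peers can deserialize the response
        public JobResponse(StreamInput in) throws IOException {
            super(in);
            this.id = in.readString();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(id);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder.startObject().field("_id", id).endObject();
        }
    }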
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; public static final GetAnomalyDetectorAction INSTANCE = new GetAnomalyDetectorAction(); private GetAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index 1636e6181..38f06f55e 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -16,13 +16,13 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.EntityProfile; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.EntityProfile; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 3b040c9e1..0bae4a5ce 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -11,88 +11,43 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR; -import static org.opensearch.ad.model.ADTaskType.ALL_DETECTOR_TASK_TYPES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.get.MultiGetItemResponse; -import org.opensearch.action.get.MultiGetRequest; -import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ADEntityProfileRunner; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.AnomalyDetectorProfileRunner; -import org.opensearch.ad.EntityProfileRunner; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; +import 
org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.CheckedConsumer; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.Name; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; -import com.google.common.collect.Sets; - -public class GetAnomalyDetectorTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); +public class GetAnomalyDetectorTransportAction extends + BaseGetConfigTransportAction { - private final ClusterService clusterService; - private final Client client; - private final SecurityClientUtil clientUtil; - private final Set allProfileTypeStrs; - private final Set allProfileTypes; - private final Set defaultDetectorProfileTypes; - private final Set allEntityProfileTypeStrs; - private final Set allEntityProfileTypes; - private final Set defaultEntityProfileTypes; - private final NamedXContentRegistry xContentRegistry; - private final DiscoveryNodeFilterer nodeFilter; - private final TransportService transportService; - private volatile Boolean filterByEnabled; - private final ADTaskManager adTaskManager; + public static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); @Inject public GetAnomalyDetectorTransportAction( @@ -104,321 +59,105 @@ public GetAnomalyDetectorTransportAction( SecurityClientUtil clientUtil, Settings settings, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager + ADTaskManager adTaskManager, + ADTaskProfileRunner adTaskProfileRunner ) { - super(GetAnomalyDetectorAction.NAME, transportService, actionFilters, GetAnomalyDetectorRequest::new); - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; - List allProfiles = Arrays.asList(DetectorProfileName.values()); - this.allProfileTypes = EnumSet.copyOf(allProfiles); - this.allProfileTypeStrs = getProfileListStrs(allProfiles); - List defaultProfiles = Arrays.asList(DetectorProfileName.ERROR, DetectorProfileName.STATE); - this.defaultDetectorProfileTypes = new 
HashSet(defaultProfiles); - - List allEntityProfiles = Arrays.asList(EntityProfileName.values()); - this.allEntityProfileTypes = EnumSet.copyOf(allEntityProfiles); - this.allEntityProfileTypeStrs = getProfileListStrs(allEntityProfiles); - List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); - this.defaultEntityProfileTypes = new HashSet(defaultEntityProfiles); - - this.xContentRegistry = xContentRegistry; - this.nodeFilter = nodeFilter; - filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.transportService = transportService; - this.adTaskManager = adTaskManager; + super( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + settings, + xContentRegistry, + adTaskManager, + GetAnomalyDetectorAction.NAME, + AnomalyDetector.class, + AnomalyDetector.PARSE_FIELD_NAME, + ADTaskType.ALL_DETECTOR_TASK_TYPES, + ADTaskType.REALTIME_HC_DETECTOR.name(), + ADTaskType.REALTIME_SINGLE_ENTITY.name(), + ADTaskType.HISTORICAL_HC_DETECTOR.name(), + ADTaskType.HISTORICAL_SINGLE_ENTITY.name(), + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, + adTaskProfileRunner + ); } @Override - protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionListener actionListener) { - String detectorID = request.getDetectorID(); - User user = getUserContext(client); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute( - user, - detectorID, - filterByEnabled, - listener, - (anomalyDetector) -> getExecute(request, listener), - client, - clusterService, - xContentRegistry, - AnomalyDetector.class - ); - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); + protected void fillInHistoricalTaskforBwc(Map tasks, Optional historicalAdTask) { + if (tasks.containsKey(ADTaskType.HISTORICAL.name())) { + historicalAdTask = Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name())); } } - protected void getExecute(GetAnomalyDetectorRequest request, ActionListener listener) { - String detectorID = request.getDetectorID(); - String typesStr = request.getTypeStr(); - String rawPath = request.getRawPath(); - Entity entity = request.getEntity(); - boolean all = request.isAll(); - boolean returnJob = request.isReturnJob(); - boolean returnTask = request.isReturnTask(); - - try { - if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { - if (entity != null) { - Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); - EntityProfileRunner profileRunner = new EntityProfileRunner( - client, - clientUtil, - xContentRegistry, - TimeSeriesSettings.NUM_MIN_SAMPLES - ); - profileRunner.profile(detectorID, entity, entityProfilesToCollect, ActionListener.wrap(profile -> { - listener - .onResponse( - new GetAnomalyDetectorResponse( - 0, - null, - 0, - 0, - null, - null, - false, - null, - null, - false, - null, - null, - profile, - true - ) - ); - }, e -> listener.onFailure(e))); - } else { - Set profilesToCollect = getProfilesToCollect(typesStr, all); - AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( - client, - clientUtil, - xContentRegistry, - nodeFilter, - TimeSeriesSettings.NUM_MIN_SAMPLES, - transportService, - adTaskManager - ); - 
profileRunner.profile(detectorID, getProfileActionListener(listener), profilesToCollect); - } - } else { - if (returnTask) { - adTaskManager.getAndExecuteOnLatestADTasks(detectorID, null, null, ALL_DETECTOR_TASK_TYPES, (taskList) -> { - Optional realtimeAdTask = Optional.empty(); - Optional historicalAdTask = Optional.empty(); - - if (taskList != null && taskList.size() > 0) { - Map adTasks = new HashMap<>(); - List duplicateAdTasks = new ArrayList<>(); - for (ADTask task : taskList) { - if (adTasks.containsKey(task.getTaskType())) { - LOG - .info( - "Found duplicate latest task of detector {}, task id: {}, task type: {}", - detectorID, - task.getTaskType(), - task.getTaskId() - ); - duplicateAdTasks.add(task); - continue; - } - adTasks.put(task.getTaskType(), task); - } - if (duplicateAdTasks.size() > 0) { - adTaskManager.resetLatestFlagAsFalse(duplicateAdTasks); - } - - if (adTasks.containsKey(ADTaskType.REALTIME_HC_DETECTOR.name())) { - realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_HC_DETECTOR.name())); - } else if (adTasks.containsKey(ADTaskType.REALTIME_SINGLE_ENTITY.name())) { - realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_SINGLE_ENTITY.name())); - } - if (adTasks.containsKey(ADTaskType.HISTORICAL_HC_DETECTOR.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_HC_DETECTOR.name())); - } else if (adTasks.containsKey(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())); - } else if (adTasks.containsKey(ADTaskType.HISTORICAL.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL.name())); - } - } - getDetectorAndJob(detectorID, returnJob, returnTask, realtimeAdTask, historicalAdTask, listener); - }, transportService, true, 2, listener); - } else { - getDetectorAndJob(detectorID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); - } - } - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } - } - - private void getDetectorAndJob( - String detectorID, + @Override + protected GetAnomalyDetectorResponse createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + AnomalyDetector config, + Job job, boolean returnJob, + Optional realtimeTask, + Optional historicalTask, boolean returnTask, - Optional realtimeAdTask, - Optional historicalAdTask, - ActionListener listener + RestStatus restStatus, + DetectorProfile detectorProfile, + EntityProfile entityProfile, + boolean profileResponse ) { - MultiGetRequest.Item adItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, detectorID); - MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); - if (returnJob) { - MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, detectorID); - multiGetRequest.add(adJobItem); - } - client.multiGet(multiGetRequest, onMultiGetResponse(listener, returnJob, returnTask, realtimeAdTask, historicalAdTask, detectorID)); + return new GetAnomalyDetectorResponse( + version, + id, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask.orElse(null), + historicalTask.orElse(null), + returnTask, + RestStatus.OK, + detectorProfile, + entityProfile, + profileResponse + ); } - private ActionListener onMultiGetResponse( - ActionListener listener, - boolean returnJob, - boolean returnTask, - Optional realtimeAdTask, - Optional historicalAdTask, - String detectorId + @Override + protected ADEntityProfileRunner 
createEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples ) { - return new ActionListener() { - @Override - public void onResponse(MultiGetResponse multiGetResponse) { - MultiGetItemResponse[] responses = multiGetResponse.getResponses(); - AnomalyDetector detector = null; - Job adJob = null; - String id = null; - long version = 0; - long seqNo = 0; - long primaryTerm = 0; - - for (MultiGetItemResponse response : responses) { - if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { - if (response.getResponse() == null || !response.getResponse().isExists()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - id = response.getId(); - version = response.getResponse().getVersion(); - primaryTerm = response.getResponse().getPrimaryTerm(); - seqNo = response.getResponse().getSeqNo(); - if (!response.getResponse().isSourceEmpty()) { - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - detector = parser.namedObject(AnomalyDetector.class, AnomalyDetector.PARSE_FIELD_NAME, null); - } catch (Exception e) { - String message = "Failed to parse detector job " + detectorId; - listener.onFailure(buildInternalServerErrorResponse(e, message)); - return; - } - } - } - - if (CommonName.JOB_INDEX.equals(response.getIndex())) { - if (response.getResponse() != null - && response.getResponse().isExists() - && !response.getResponse().isSourceEmpty()) { - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - adJob = Job.parse(parser); - } catch (Exception e) { - String message = "Failed to parse detector job " + detectorId; - listener.onFailure(buildInternalServerErrorResponse(e, message)); - return; - } - } - } - } - listener - .onResponse( - new GetAnomalyDetectorResponse( - version, - id, - primaryTerm, - seqNo, - detector, - adJob, - returnJob, - realtimeAdTask.orElse(null), - historicalAdTask.orElse(null), - returnTask, - RestStatus.OK, - null, - null, - false - ) - ); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }; - } - - private ActionListener getProfileActionListener(ActionListener listener) { - return ActionListener.wrap(new CheckedConsumer() { - @Override - public void accept(DetectorProfile profile) throws Exception { - listener - .onResponse( - new GetAnomalyDetectorResponse(0, null, 0, 0, null, null, false, null, null, false, null, profile, null, true) - ); - } - }, exception -> { listener.onFailure(exception); }); - } - - private OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { - LOG.error(errorMsg, e); - return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); - } - - /** - * - * @param typesStr a list of input profile types separated by comma - * @param all whether we should return all profile in the response - * @return profiles to collect for a detector - */ - private Set getProfilesToCollect(String typesStr, boolean all) { - if (all) { - return this.allProfileTypes; - } else if (Strings.isEmpty(typesStr)) { - return this.defaultDetectorProfileTypes; - 
} else { - // Filter out unsupported types - Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); - return DetectorProfileName.getNames(Sets.intersection(allProfileTypeStrs, typesInRequest)); - } + return new ADEntityProfileRunner(client, clientUtil, xContentRegistry, TimeSeriesSettings.NUM_MIN_SAMPLES); } - /** - * - * @param typesStr a list of input profile types separated by comma - * @param all whether we should return all profile in the response - * @return profiles to collect for an entity - */ - private Set getEntityProfilesToCollect(String typesStr, boolean all) { - if (all) { - return this.allEntityProfileTypes; - } else if (Strings.isEmpty(typesStr)) { - return this.defaultEntityProfileTypes; - } else { - // Filter out unsupported types - Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); - return EntityProfileName.getNames(Sets.intersection(allEntityProfileTypeStrs, typesInRequest)); - } + @Override + protected AnomalyDetectorProfileRunner createProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ADTaskManager taskManager, + ADTaskProfileRunner taskProfileRunner + ) { + return new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager, + taskProfileRunner + ); } - private Set getProfileListStrs(List profileList) { - return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); - } } diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java index 9ee038336..56103dfc9 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class IndexAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; public static final IndexAnomalyDetectorAction INSTANCE = new IndexAnomalyDetectorAction(); private IndexAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java index 572e847f9..6a4bb6d1d 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java @@ -34,6 +34,9 @@ public class IndexAnomalyDetectorRequest extends ActionRequest { private Integer maxSingleEntityAnomalyDetectors; private Integer maxMultiEntityAnomalyDetectors; private Integer maxAnomalyFeatures; + // added during refactoring for forecasting. It is fine we add a new field + // since the request is handled by the same node. 
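To make the comment above concrete: because IndexAnomalyDetectorRequest is built and consumed on the same coordinating node, both ends of the stream always agree on the wire format, so the plain readInt/writeInt pair below needs no guard. If the request did cross node boundaries during a rolling upgrade, the usual OpenSearch pattern would be a version-gated optional field; a hedged sketch, with a hypothetical cut-over version:

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        // ... existing fields ...
        if (out.getVersion().onOrAfter(Version.V_2_9_0)) {    // hypothetical cut-over version
            out.writeOptionalVInt(maxCategoricalFields);      // older peers never see the field
        }
    }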
+ private Integer maxCategoricalFields; public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { super(in); @@ -47,6 +50,7 @@ public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { maxSingleEntityAnomalyDetectors = in.readInt(); maxMultiEntityAnomalyDetectors = in.readInt(); maxAnomalyFeatures = in.readInt(); + maxCategoricalFields = in.readInt(); } public IndexAnomalyDetectorRequest( @@ -59,7 +63,8 @@ public IndexAnomalyDetectorRequest( TimeValue requestTimeout, Integer maxSingleEntityAnomalyDetectors, Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures + Integer maxAnomalyFeatures, + Integer maxCategoricalFields ) { super(); this.detectorID = detectorID; @@ -72,6 +77,7 @@ public IndexAnomalyDetectorRequest( this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; this.maxAnomalyFeatures = maxAnomalyFeatures; + this.maxCategoricalFields = maxCategoricalFields; } public String getDetectorID() { @@ -114,6 +120,10 @@ public Integer getMaxAnomalyFeatures() { return maxAnomalyFeatures; } + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -127,6 +137,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxSingleEntityAnomalyDetectors); out.writeInt(maxMultiEntityAnomalyDetectors); out.writeInt(maxAnomalyFeatures); + out.writeInt(maxCategoricalFields); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java index ac0e560b1..5d9b69910 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -16,7 +16,6 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; import static org.opensearch.timeseries.util.ParseUtils.getConfig; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; import java.util.List; @@ -46,8 +45,10 @@ import org.opensearch.rest.RestRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -93,7 +94,7 @@ public IndexAnomalyDetectorTransportAction( @Override protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionListener actionListener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); String detectorId = request.getDetectorID(); RestRequest.Method method = request.getMethod(); String errorMessage = method == RestRequest.Method.PUT ? 
FAIL_TO_UPDATE_DETECTOR : FAIL_TO_CREATE_DETECTOR; @@ -116,8 +117,12 @@ private void resolveUserAndExecute( try { // Check if user has backend roles // When filter by is enabled, block users creating/updating detectors who do not have backend roles. - if (filterByEnabled && !checkFilterByBackendRoles(requestedUser, listener)) { - return; + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new TimeSeriesException(error)); + return; + } } if (method == RestRequest.Method.PUT) { // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to @@ -164,6 +169,7 @@ protected void adExecute( Integer maxSingleEntityAnomalyDetectors = request.getMaxSingleEntityAnomalyDetectors(); Integer maxMultiEntityAnomalyDetectors = request.getMaxMultiEntityAnomalyDetectors(); Integer maxAnomalyFeatures = request.getMaxAnomalyFeatures(); + Integer maxCategoricalFields = request.getMaxCategoricalFields(); storedContext.restore(); checkIndicesAndExecute(detector.getIndices(), () -> { @@ -175,7 +181,6 @@ protected void adExecute( client, clientUtil, transportService, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -186,6 +191,7 @@ protected void adExecute( maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry, detectorUser, @@ -193,7 +199,7 @@ protected void adExecute( searchFeatureDao, settings ); - indexAnomalyDetectorActionHandler.start(); + indexAnomalyDetectorActionHandler.start(listener); }, listener); } diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java index c90ecc446..5ae8d6c35 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class PreviewAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
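A design note on the backend-role change a few hunks up: checkFilterByBackendRoles previously took the listener and failed it as a side effect, returning a boolean; the refactored form returns the error message (or null) and leaves failure handling to each caller, so the same check can serve transports with different response and exception types. A minimal sketch of the returning form, assuming only the standard User accessors:

    // returns null when the user passes the filter; otherwise a message the
    // caller wraps in whatever exception its own transport needs
    public static String checkFilterByBackendRoles(User user) {
        if (user == null) {
            return "Filter by backend roles is enabled and the user is null";
        }
        if (user.getBackendRoles() == null || user.getBackendRoles().isEmpty()) {
            return String.format(Locale.ROOT, "User %s has no backend roles", user.getName());
        }
        return null;    // validation passed
    }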
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; public static final PreviewAnomalyDetectorAction INSTANCE = new PreviewAnomalyDetectorAction(); private PreviewAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java index 5f6c6c9d3..ef82c43b2 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java @@ -16,7 +16,6 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -56,6 +55,7 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; @@ -103,7 +103,7 @@ protected void doExecute( ActionListener actionListener ) { String detectorId = request.getId(); - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_PREVIEW_DETECTOR); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { resolveUserAndExecute( diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java index 147ff74cb..b38a088eb 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class RCFPollingAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling"; public static final RCFPollingAction INSTANCE = new RCFPollingAction(); private RCFPollingAction() { diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java index a8bd64603..c7783cb8f 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java @@ -19,8 +19,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -30,6 +29,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.transport.TransportException; @@ -48,7 +48,7 @@ public class RCFPollingTransportAction extends HandledTransportAction rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + Optional rcfNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID); if (!rcfNode.isPresent()) { listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); return; diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java index 3480e880a..f551f97df 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class RCFResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
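The RCFPollingTransportAction hunk above renames the hash-ring lookup; assembled end to end, the routing it performs looks like this (a sketch; assumes SingleStreamModelIdMapper.getRcfModelId with this signature):

    // map the detector to its partition-0 RCF model id, then ask the hash ring
    // which same-version data node currently owns that model
    String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0);
    Optional<DiscoveryNode> rcfNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID);
    if (!rcfNode.isPresent()) {
        listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG));
        return;
    }
    // forward the poll to the owning node, which hosts the model in memory and
    // can answer with its total update count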
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; public static final RCFResultAction INSTANCE = new RCFResultAction(); private RCFResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java index d7df181bb..59ca12965 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java @@ -20,14 +20,14 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.stats.ADStats; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.stats.StatNames; @@ -36,7 +36,7 @@ public class RCFResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(RCFResultTransportAction.class); - private ModelManager manager; + private ADModelManager manager; private CircuitBreakerService adCircuitBreakerService; private HashRing hashRing; private ADStats adStats; @@ -45,7 +45,7 @@ public class RCFResultTransportAction extends HandledTransportAction { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; public static final SearchADTasksAction INSTANCE = new SearchADTasksAction(); private SearchADTasksAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java index c15ece9ab..90ae6cede 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; public static final SearchAnomalyDetectorAction INSTANCE = new SearchAnomalyDetectorAction(); private SearchAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java index 3f4f7c2fc..50f3b60d4 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.SearchConfigInfoResponse; -public class SearchAnomalyDetectorInfoAction extends ActionType { +public class SearchAnomalyDetectorInfoAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; public static final SearchAnomalyDetectorInfoAction INSTANCE = new SearchAnomalyDetectorInfoAction(); private SearchAnomalyDetectorInfoAction() { - super(NAME, SearchAnomalyDetectorInfoResponse::new); + super(NAME, SearchConfigInfoResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java index 7db017d3d..c83ac9ebd 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java @@ -11,34 +11,14 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR_INFO; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.action.ActionListener; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.TermsQueryBuilder; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.transport.BaseSearchConfigInfoTransportAction; import org.opensearch.transport.TransportService; -public class SearchAnomalyDetectorInfoTransportAction extends - HandledTransportAction { - private static final Logger LOG = LogManager.getLogger(SearchAnomalyDetectorInfoTransportAction.class); - private final Client client; - private final ClusterService clusterService; 
+public class SearchAnomalyDetectorInfoTransportAction extends BaseSearchConfigInfoTransportAction { @Inject public SearchAnomalyDetectorInfoTransportAction( @@ -47,80 +27,6 @@ public SearchAnomalyDetectorInfoTransportAction( Client client, ClusterService clusterService ) { - super(SearchAnomalyDetectorInfoAction.NAME, transportService, actionFilters, SearchAnomalyDetectorInfoRequest::new); - this.client = client; - this.clusterService = clusterService; - } - - @Override - protected void doExecute( - Task task, - SearchAnomalyDetectorInfoRequest request, - ActionListener actionListener - ) { - String name = request.getName(); - String rawPath = request.getRawPath(); - ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR_INFO); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - SearchRequest searchRequest = new SearchRequest().indices(CommonName.CONFIG_INDEX); - if (rawPath.endsWith(RestHandlerUtils.COUNT)) { - // Count detectors - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); - client.search(searchRequest, new ActionListener() { - - @Override - public void onResponse(SearchResponse searchResponse) { - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse( - searchResponse.getHits().getTotalHits().value, - false - ); - listener.onResponse(response); - } - - @Override - public void onFailure(Exception e) { - if (e.getClass() == IndexNotFoundException.class) { - // Anomaly Detectors index does not exist - // Could be that user is creating first detector - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, false); - listener.onResponse(response); - } else { - listener.onFailure(e); - } - } - }); - } else { - // Match name with existing detectors - TermsQueryBuilder query = QueryBuilders.termsQuery("name.keyword", name); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - searchRequest.source(searchSourceBuilder); - client.search(searchRequest, new ActionListener() { - - @Override - public void onResponse(SearchResponse searchResponse) { - boolean nameExists = false; - nameExists = searchResponse.getHits().getTotalHits().value > 0; - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, nameExists); - listener.onResponse(response); - } - - @Override - public void onFailure(Exception e) { - if (e.getClass() == IndexNotFoundException.class) { - // Anomaly Detectors index does not exist - // Could be that user is creating first detector - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, false); - listener.onResponse(response); - } else { - listener.onFailure(e); - } - } - }); - } - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } + super(transportService, actionFilters, client, SearchAnomalyDetectorInfoAction.NAME); } } diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java index 7e0178393..e2a5969bd 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.CommonValue; +import 
org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyResultAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; public static final SearchAnomalyResultAction INSTANCE = new SearchAnomalyResultAction(); private SearchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java index ee89c4179..8956eeb1d 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchTopAnomalyResultAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; public static final SearchTopAnomalyResultAction INSTANCE = new SearchTopAnomalyResultAction(); private SearchTopAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java index 82a1a02a3..afe3c4729 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java @@ -46,7 +46,6 @@ import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.script.Script; -import org.opensearch.script.ScriptType; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; @@ -64,10 +63,10 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.QueryUtil; import org.opensearch.transport.TransportService; -import com.google.common.collect.ImmutableMap; - /** * Transport action to fetch top anomaly results for some HC detector. Generates a * query based on user input to fetch aggregated entity results. 
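The aggregation this javadoc describes pairs one composite terms source per category field with a bucket sort over the anomaly grade, and (per the hunk below) now takes its per-category painless script from QueryUtil. A hedged sketch of the overall shape, with hypothetical names for everything the diff doesn't show:

    // one composite source per category field, keyed by a painless script that
    // pulls that field's value out of the result document's entity array
    List<CompositeValuesSourceBuilder<?>> sources = new ArrayList<>();
    sources.add(new TermsValuesSourceBuilder("service").script(QueryUtil.getScriptForCategoryField("service")));

    CompositeAggregationBuilder composite = AggregationBuilders
        .composite("multi_buckets", sources)    // hypothetical aggregation name
        .subAggregation(AggregationBuilders.max("max_grade").field("anomaly_grade"))
        .subAggregation(
            PipelineAggregatorBuilders.bucketSort("sort", Arrays.asList(new FieldSortBuilder("max_grade").order(SortOrder.DESC)))
        );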
@@ -219,7 +218,7 @@ public SearchTopAnomalyResultTransportAction( @Override protected void doExecute(Task task, SearchTopAnomalyResultRequest request, ActionListener listener) { - GetAnomalyDetectorRequest getAdRequest = new GetAnomalyDetectorRequest( + GetConfigRequest getAdRequest = new GetConfigRequest( request.getId(), // The default version value used in org.opensearch.rest.action.RestActions.parseVersion() -3L, @@ -506,7 +505,7 @@ private QueryBuilder generateQuery(SearchTopAnomalyResultRequest request) { private AggregationBuilder generateAggregation(SearchTopAnomalyResultRequest request) { List> sources = new ArrayList<>(); for (String categoryField : request.getCategoryFields()) { - Script script = getScriptForCategoryField(categoryField); + Script script = QueryUtil.getScriptForCategoryField(categoryField); sources.add(new TermsValuesSourceBuilder(categoryField).script(script)); } @@ -529,36 +528,6 @@ private AggregationBuilder generateAggregation(SearchTopAnomalyResultRequest req .subAggregation(bucketSort); } - /** - * Generates the painless script to fetch results that have an entity name matching the passed-in category field. - * - * @param categoryField the category field to be used as a source - * @return the painless script used to get all docs with entity name values matching the category field - */ - private Script getScriptForCategoryField(String categoryField) { - StringBuilder builder = new StringBuilder() - .append("String value = null;") - .append("if (params == null || params._source == null || params._source.entity == null) {") - .append("return \"\"") - .append("}") - .append("for (item in params._source.entity) {") - .append("if (item[\"name\"] == params[\"categoryField\"]) {") - .append("value = item['value'];") - .append("break;") - .append("}") - .append("}") - .append("return value;"); - - // The last argument contains the K/V pair to inject the categoryField value into the script - return new Script( - ScriptType.INLINE, - "painless", - builder.toString(), - Collections.emptyMap(), - ImmutableMap.of("categoryField", categoryField) - ); - } - /** * Creates a descending-ordered List from a min heap. * diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java index 3c1f53d9d..8172aeeaf 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; -public class StatsAnomalyDetectorAction extends ActionType { +public class StatsAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; public static final StatsAnomalyDetectorAction INSTANCE = new StatsAnomalyDetectorAction(); private StatsAnomalyDetectorAction() { - super(NAME, StatsAnomalyDetectorResponse::new); + super(NAME, StatsTimeSeriesResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java index 5ce2c2319..d96887e12 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java @@ -11,47 +11,31 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_STATS; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.tasks.Task; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.BaseStatsTransportAction; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.StatsResponse; import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.transport.TransportService; -public class StatsAnomalyDetectorTransportAction extends HandledTransportAction { +public class StatsAnomalyDetectorTransportAction extends BaseStatsTransportAction { public static final String DETECTOR_TYPE_AGG = "detector_type_agg"; - private final Logger logger = LogManager.getLogger(StatsAnomalyDetectorTransportAction.class); - - private final Client client; - private final ADStats adStats; - private final ClusterService clusterService; @Inject public StatsAnomalyDetectorTransportAction( @@ -62,55 +46,7 @@ public StatsAnomalyDetectorTransportAction( ClusterService clusterService ) { - super(StatsAnomalyDetectorAction.NAME, transportService, actionFilters, ADStatsRequest::new); - this.client = client; - this.adStats = adStats; - this.clusterService = clusterService; - } - - @Override - protected void doExecute(Task task, ADStatsRequest request, ActionListener actionListener) { - ActionListener listener = 
wrapRestActionListener(actionListener, FAIL_TO_GET_STATS); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - getStats(client, listener, request); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } - } - - /** - * Make the 2 requests to get the node and cluster statistics - * - * @param client Client - * @param listener Listener to send response - * @param adStatsRequest Request containing stats to be retrieved - */ - private void getStats(Client client, ActionListener<StatsAnomalyDetectorResponse> listener, ADStatsRequest adStatsRequest) { - // Use MultiResponsesDelegateActionListener to execute 2 async requests and create the response once they finish - MultiResponsesDelegateActionListener<ADStatsResponse> delegateListener = new MultiResponsesDelegateActionListener<>( - getRestStatsListener(listener), - 2, - "Unable to return AD Stats", - false - ); - - getClusterStats(client, delegateListener, adStatsRequest); - getNodeStats(client, delegateListener, adStatsRequest); - } - - /** - * Listener sends response once Node Stats and Cluster Stats are gathered - * - * @param listener Listener to send response - * @return ActionListener for ADStatsResponse - */ - private ActionListener<ADStatsResponse> getRestStatsListener(ActionListener<StatsAnomalyDetectorResponse> listener) { - return ActionListener - .wrap( - adStatsResponse -> { listener.onResponse(new StatsAnomalyDetectorResponse(adStatsResponse)); }, - exception -> listener.onFailure(new OpenSearchStatusException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)) - ); + super(transportService, actionFilters, client, adStats, clusterService, StatsAnomalyDetectorAction.NAME); } /** @@ -121,15 +57,16 @@ private ActionListener<ADStatsResponse> getRestStatsListener(ActionListener<StatsAnomalyDetectorResponse> listener) private void getClusterStats( Client client, - MultiResponsesDelegateActionListener<ADStatsResponse> listener, - ADStatsRequest adStatsRequest + MultiResponsesDelegateActionListener<StatsResponse> listener, + StatsRequest adStatsRequest ) { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); if ((adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName()) - || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()) - || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.HC_DETECTOR_COUNT.getName())) && clusterService.state().getRoutingTable().hasIndex(CommonName.CONFIG_INDEX)) { TermsAggregationBuilder termsAgg = AggregationBuilders.terms(DETECTOR_TYPE_AGG).field(AnomalyDetector.DETECTOR_TYPE_FIELD); @@ -156,13 +93,13 @@ private void getClusterStats( } } if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.DETECTOR_COUNT.getName()).setValue(totalDetectors); + stats.getStat(StatNames.DETECTOR_COUNT.getName()).setValue(totalDetectors); } - if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())) { + stats.getStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); } - if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) { - 
adStats.getStat(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.HC_DETECTOR_COUNT.getName())) { + stats.getStat(StatNames.HC_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); } adStatsResponse.setClusterStats(getClusterStatsMap(adStatsRequest)); listener.onResponse(adStatsResponse); @@ -173,24 +110,6 @@ private void getClusterStats( } } - /** - * Collect Cluster Stats into map to be retrieved - * - * @param adStatsRequest Request containing stats to be retrieved - * @return Map containing Cluster Stats - */ - private Map<String, Object> getClusterStatsMap(ADStatsRequest adStatsRequest) { - Map<String, Object> clusterStats = new HashMap<>(); - Set<String> statsToBeRetrieved = adStatsRequest.getStatsToBeRetrieved(); - adStats - .getClusterStats() - .entrySet() - .stream() - .filter(s -> statsToBeRetrieved.contains(s.getKey())) - .forEach(s -> clusterStats.put(s.getKey(), s.getValue().getValue())); - return clusterStats; - } - /** * Make async request to get the Anomaly Detection statistics from each node and, onResponse, set the * ADStatsNodesResponse field of ADStatsResponse @@ -199,14 +118,11 @@ private Map<String, Object> getClusterStatsMap(ADStatsRequest adStatsRequest) { * @param listener MultiResponsesDelegateActionListener to be used once both requests complete * @param adStatsRequest Request containing stats to be retrieved */ - private void getNodeStats( - Client client, - MultiResponsesDelegateActionListener<ADStatsResponse> listener, - ADStatsRequest adStatsRequest - ) { + @Override + protected void getNodeStats(Client client, MultiResponsesDelegateActionListener<StatsResponse> listener, StatsRequest adStatsRequest) { client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { - ADStatsResponse restADStatsResponse = new ADStatsResponse(); - restADStatsResponse.setADStatsNodesResponse(adStatsResponse); + StatsResponse restADStatsResponse = new StatsResponse(); + restADStatsResponse.setStatsNodesResponse(adStatsResponse); listener.onResponse(restADStatsResponse); }, listener::onFailure)); } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java index 5c7182920..15f617e78 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StopConfigResponse; -public class StopDetectorAction extends ActionType<StopDetectorResponse> { +public class StopDetectorAction extends ActionType<StopConfigResponse> { // Internal Action which is not used for public facing RestAPIs. 
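(Aside on the stats hunks above: the read side is symmetric to the setValue calls shown in getClusterStats. A minimal sketch, assuming the stat objects keep the getValue() accessor that the deleted getClusterStatsMap relied on; stats and log stand in for the injected members of the transport action.)

    // Cluster-level stats written by getClusterStats are readable through the same StatNames keys.
    Object detectorCount = stats.getStat(StatNames.DETECTOR_COUNT.getName()).getValue();
    Object hcDetectorCount = stats.getStat(StatNames.HC_DETECTOR_COUNT.getName()).getValue();
    log.info("detector count: {}, high-cardinality detector count: {}", detectorCount, hcDetectorCount);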
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; public static final StopDetectorAction INSTANCE = new StopDetectorAction(); private StopDetectorAction() { - super(NAME, StopDetectorResponse::new); + super(NAME, StopConfigResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java index deafd8854..074165a35 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java @@ -27,10 +27,13 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; -public class StopDetectorTransportAction extends HandledTransportAction<ActionRequest, StopDetectorResponse> { +public class StopDetectorTransportAction extends HandledTransportAction<ActionRequest, StopConfigResponse> { private static final Logger LOG = LogManager.getLogger(StopDetectorTransportAction.class); @@ -44,19 +47,19 @@ public StopDetectorTransportAction( ActionFilters actionFilters, Client client ) { - super(StopDetectorAction.NAME, transportService, actionFilters, StopDetectorRequest::new); + super(StopDetectorAction.NAME, transportService, actionFilters, StopConfigRequest::new); this.client = client; this.nodeFilter = nodeFilter; } @Override - protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<StopDetectorResponse> listener) { - StopDetectorRequest request = StopDetectorRequest.fromActionRequest(actionRequest); - String adID = request.getAdID(); + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<StopConfigResponse> listener) { + StopConfigRequest request = StopConfigRequest.fromActionRequest(actionRequest); + String adID = request.getConfigID(); try { DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(adID, dataNodes); - client.execute(DeleteModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + client.execute(DeleteADModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { if (response.hasFailures()) { LOG.warn("Cannot delete all models of detector {}", adID); for (FailedNodeException failedNodeException : response.failures()) { @@ -64,14 +67,14 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener<StopDetectorResponse> listener) { } // if customers are using an updated detector and we haven't deleted old // checkpoints, customer would have trouble - listener.onResponse(new StopDetectorResponse(false)); + listener.onResponse(new StopConfigResponse(false)); } else { LOG.info("models of detector {} get deleted", adID); - listener.onResponse(new StopDetectorResponse(true)); + listener.onResponse(new StopConfigResponse(true)); } }, exception -> { LOG.error(new ParameterizedMessage("Deletion of detector [{}] has exception.", adID), exception); - listener.onResponse(new StopDetectorResponse(false)); + listener.onResponse(new StopConfigResponse(false)); })); } catch (Exception e) { LOG.error(FAIL_TO_STOP_DETECTOR + " " + adID, e); diff 
--git a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java index 1561c08dc..f8a81252a 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ThresholdResultAction extends ActionType<ThresholdResultResponse> { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; public static final ThresholdResultAction INSTANCE = new ThresholdResultAction(); private ThresholdResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java index 053d9729b..9c60fcd7f 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java @@ -15,7 +15,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.tasks.Task; @@ -24,10 +24,10 @@ public class ThresholdResultTransportAction extends HandledTransportAction<ThresholdResultRequest, ThresholdResultResponse> { private static final Logger LOG = LogManager.getLogger(ThresholdResultTransportAction.class); - private ModelManager manager; + private ADModelManager manager; @Inject - public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ModelManager manager) { + public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ADModelManager manager) { super(ThresholdResultAction.NAME, transportService, actionFilters, ThresholdResultRequest::new); this.manager = manager; } diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java index 432166ac2..cf3f2325a 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.ValidateConfigResponse; -public class ValidateAnomalyDetectorAction extends ActionType<ValidateAnomalyDetectorResponse> { +public class ValidateAnomalyDetectorAction extends ActionType<ValidateConfigResponse> { - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; public static final ValidateAnomalyDetectorAction INSTANCE = new ValidateAnomalyDetectorAction(); private ValidateAnomalyDetectorAction() { - super(NAME, ValidateAnomalyDetectorResponse::new); + super(NAME, ValidateConfigResponse::new); } } 
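(The pattern in the hunks above keeps each action NAME stable while swapping only the response class to a shared timeseries type, so existing REST routing is untouched. A minimal sketch of driving the refactored stop action end to end, assuming StopConfigRequest has a single-argument constructor taking the config id and StopConfigResponse exposes a success() accessor mirroring the boolean written by StopDetectorTransportAction; client, LOG, and detectorId are assumed to be in scope.)

    // Sketch only: exercising the stop action with the shared request/response types.
    StopConfigRequest stopRequest = new StopConfigRequest(detectorId); // constructor assumed, not shown in this diff
    client.execute(StopDetectorAction.INSTANCE, stopRequest, ActionListener.wrap(stopResponse -> {
        if (stopResponse.success()) { // accessor assumed to mirror new StopConfigResponse(true)
            LOG.info("models of detector {} deleted", detectorId);
        } else {
            LOG.warn("some models of detector {} could not be deleted", detectorId);
        }
    }, e -> LOG.error("fail to stop detector " + detectorId, e)));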
diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java deleted file mode 100644 index 3ee1f0a6e..000000000 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.transport; - -import java.io.IOException; - -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; - -public class ValidateAnomalyDetectorRequest extends ActionRequest { - - private final AnomalyDetector detector; - private final String validationType; - private final Integer maxSingleEntityAnomalyDetectors; - private final Integer maxMultiEntityAnomalyDetectors; - private final Integer maxAnomalyFeatures; - private final TimeValue requestTimeout; - - public ValidateAnomalyDetectorRequest(StreamInput in) throws IOException { - super(in); - detector = new AnomalyDetector(in); - validationType = in.readString(); - maxSingleEntityAnomalyDetectors = in.readInt(); - maxMultiEntityAnomalyDetectors = in.readInt(); - maxAnomalyFeatures = in.readInt(); - requestTimeout = in.readTimeValue(); - } - - public ValidateAnomalyDetectorRequest( - AnomalyDetector detector, - String validationType, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, - TimeValue requestTimeout - ) { - this.detector = detector; - this.validationType = validationType; - this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; - this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; - this.maxAnomalyFeatures = maxAnomalyFeatures; - this.requestTimeout = requestTimeout; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - detector.writeTo(out); - out.writeString(validationType); - out.writeInt(maxSingleEntityAnomalyDetectors); - out.writeInt(maxMultiEntityAnomalyDetectors); - out.writeInt(maxAnomalyFeatures); - out.writeTimeValue(requestTimeout); - } - - @Override - public ActionRequestValidationException validate() { - return null; - } - - public AnomalyDetector getDetector() { - return detector; - } - - public String getValidationType() { - return validationType; - } - - public Integer getMaxSingleEntityAnomalyDetectors() { - return maxSingleEntityAnomalyDetectors; - } - - public Integer getMaxMultiEntityAnomalyDetectors() { - return maxMultiEntityAnomalyDetectors; - } - - public Integer getMaxAnomalyFeatures() { - return maxAnomalyFeatures; - } - - public TimeValue getRequestTimeout() { - return requestTimeout; - } -} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java index 16eec43ac..db43e038c 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java +++ 
b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java @@ -12,61 +12,31 @@ package org.opensearch.ad.transport; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; - -import java.time.Clock; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.tasks.Task; -import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.feature.SearchFeatureDao; -import org.opensearch.timeseries.function.ExecutorFunction; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.ValidationAspect; -import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.rest.handler.Processor; +import org.opensearch.timeseries.transport.BaseValidateConfigTransportAction; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; -public class ValidateAnomalyDetectorTransportAction extends - HandledTransportAction { - private static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); - - private final Client client; - private final SecurityClientUtil clientUtil; - private final ClusterService clusterService; - private final NamedXContentRegistry xContentRegistry; - private final ADIndexManagement anomalyDetectionIndices; - private final SearchFeatureDao searchFeatureDao; - private volatile Boolean filterByEnabled; - private Clock clock; - private Settings settings; +public class ValidateAnomalyDetectorTransportAction extends BaseValidateConfigTransportAction { + public static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); @Inject public ValidateAnomalyDetectorTransportAction( @@ -80,176 +50,41 @@ public ValidateAnomalyDetectorTransportAction( TransportService transportService, 
SearchFeatureDao searchFeatureDao ) { - super(ValidateAnomalyDetectorAction.NAME, transportService, actionFilters, ValidateAnomalyDetectorRequest::new); - this.client = client; - this.clientUtil = clientUtil; - this.clusterService = clusterService; - this.xContentRegistry = xContentRegistry; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.searchFeatureDao = searchFeatureDao; - this.clock = Clock.systemUTC(); - this.settings = settings; + super( + ValidateAnomalyDetectorAction.NAME, + client, + clientUtil, + clusterService, + xContentRegistry, + settings, + anomalyDetectionIndices, + actionFilters, + transportService, + searchFeatureDao, + AD_FILTER_BY_BACKEND_ROLES + ); } @Override - protected void doExecute(Task task, ValidateAnomalyDetectorRequest request, ActionListener listener) { - User user = getUserContext(client); - AnomalyDetector anomalyDetector = request.getDetector(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - resolveUserAndExecute(user, listener, () -> validateExecute(request, user, context, listener)); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } - } - - private void resolveUserAndExecute( - User requestedUser, - ActionListener listener, - ExecutorFunction function - ) { - try { - // Check if user has backend roles - // When filter by is enabled, block users validating detectors who do not have backend roles. - if (filterByEnabled && !checkFilterByBackendRoles(requestedUser, listener)) { - return; - } - // Validate Detector - function.execute(); - } catch (Exception e) { - listener.onFailure(e); - } - } - - private void validateExecute( - ValidateAnomalyDetectorRequest request, - User user, - ThreadContext.StoredContext storedContext, - ActionListener listener - ) { - storedContext.restore(); - AnomalyDetector detector = request.getDetector(); - ActionListener validateListener = ActionListener.wrap(response -> { - logger.debug("Result of validation process " + response); - // forcing response to be empty - listener.onResponse(new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null)); - }, exception -> { - if (exception instanceof ValidationException) { - // ADValidationException is converted as validation issues returned as response to user - DetectorValidationIssue issue = parseADValidationException((ValidationException) exception); - listener.onResponse(new ValidateAnomalyDetectorResponse(issue)); - return; - } - logger.error(exception); - listener.onFailure(exception); - }); - checkIndicesAndExecute(detector.getIndices(), () -> { - ValidateAnomalyDetectorActionHandler handler = new ValidateAnomalyDetectorActionHandler( - clusterService, - client, - clientUtil, - validateListener, - anomalyDetectionIndices, - detector, - request.getRequestTimeout(), - request.getMaxSingleEntityAnomalyDetectors(), - request.getMaxMultiEntityAnomalyDetectors(), - request.getMaxAnomalyFeatures(), - RestRequest.Method.POST, - xContentRegistry, - user, - searchFeatureDao, - request.getValidationType(), - clock, - settings - ); - try { - handler.start(); - } catch (Exception exception) { - String errorMessage = String - .format(Locale.ROOT, "Unknown exception caught while validating detector %s", request.getDetector()); - logger.error(errorMessage, exception); - 
listener.onFailure(exception); - } - }, listener); - } - - protected DetectorValidationIssue parseADValidationException(ValidationException exception) { - String originalErrorMessage = exception.getMessage(); - String errorMessage = ""; - Map subIssues = null; - IntervalTimeConfiguration intervalSuggestion = exception.getIntervalSuggestion(); - switch (exception.getType()) { - case FEATURE_ATTRIBUTES: - int firstLeftBracketIndex = originalErrorMessage.indexOf("["); - int lastRightBracketIndex = originalErrorMessage.lastIndexOf("]"); - if (firstLeftBracketIndex != -1) { - // if feature issue messages are between square brackets like - // [Feature has issue: A, Feature has issue: B] - errorMessage = originalErrorMessage.substring(firstLeftBracketIndex + 1, lastRightBracketIndex); - subIssues = getFeatureSubIssuesFromErrorMessage(errorMessage); - } else { - // features having issue like over max feature limit, duplicate feature name, etc. - errorMessage = originalErrorMessage; - } - break; - case NAME: - case CATEGORY: - case DETECTION_INTERVAL: - case FILTER_QUERY: - case TIMEFIELD_FIELD: - case SHINGLE_SIZE_FIELD: - case WINDOW_DELAY: - case RESULT_INDEX: - case GENERAL_SETTINGS: - case AGGREGATION: - case TIMEOUT: - case INDICES: - errorMessage = originalErrorMessage; - break; - } - return new DetectorValidationIssue(exception.getAspect(), exception.getType(), errorMessage, subIssues, intervalSuggestion); - } - - // Example of method output: - // String input:Feature has invalid query returning empty aggregated data: average_total_rev, Feature has invalid query causing runtime - // exception: average_total_rev-2 - // output: "sub_issues": { - // "average_total_rev": "Feature has invalid query returning empty aggregated data", - // "average_total_rev-2": "Feature has invalid query causing runtime exception" - // } - private Map getFeatureSubIssuesFromErrorMessage(String errorMessage) { - Map result = new HashMap<>(); - String[] subIssueMessagesSuffix = errorMessage.split(", "); - for (int i = 0; i < subIssueMessagesSuffix.length; i++) { - result.put(subIssueMessagesSuffix[i].split(": ")[1], subIssueMessagesSuffix[i].split(": ")[0]); - } - return result; - } - - private void checkIndicesAndExecute( - List indices, - ExecutorFunction function, - ActionListener listener - ) { - SearchRequest searchRequest = new SearchRequest() - .indices(indices.toArray(new String[0])) - .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); - client.search(searchRequest, ActionListener.wrap(r -> function.execute(), e -> { - if (e instanceof IndexNotFoundException) { - // IndexNotFoundException is converted to a ADValidationException that gets - // parsed to a DetectorValidationIssue that is returned to - // the user as a response indicating index doesn't exist - DetectorValidationIssue issue = parseADValidationException( - new ValidationException(ADCommonMessages.INDEX_NOT_FOUND, ValidationIssueType.INDICES, ValidationAspect.DETECTOR) - ); - listener.onResponse(new ValidateAnomalyDetectorResponse(issue)); - return; - } - logger.error(e); - listener.onFailure(e); - })); + protected Processor createProcessor(Config detector, ValidateConfigRequest request, User user) { + return new ValidateAnomalyDetectorActionHandler( + clusterService, + client, + clientUtil, + indexManagement, + detector, + request.getRequestTimeout(), + request.getMaxSingleEntityAnomalyDetectors(), + request.getMaxMultiEntityAnomalyDetectors(), + request.getMaxAnomalyFeatures(), + request.getMaxCategoricalFields(), + 
RestRequest.Method.POST, + xContentRegistry, + user, + searchFeatureDao, + request.getValidationType(), + clock, + settings + ); } } diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..175017450 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.transport.ADResultBulkAction; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +public class ADIndexMemoryPressureAwareResultHandler extends + IndexMemoryPressureAwareResultHandler { + private static final Logger LOG = LogManager.getLogger(ADIndexMemoryPressureAwareResultHandler.class); + + @Inject + public ADIndexMemoryPressureAwareResultHandler(Client client, ADIndexManagement anomalyDetectionIndices) { + super(client, anomalyDetectionIndices); + } + + @Override + public void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java index 8e23243d8..7a5312652 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java @@ -11,74 +11,18 @@ package org.opensearch.ad.transport.handler; -import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_SEARCH; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; -import static org.opensearch.timeseries.util.ParseUtils.isAdmin; -import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.search.SearchRequest; -import 
org.opensearch.action.search.SearchResponse; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.transport.handler.SearchHandler; /** * Handle general search request, check user role and return search response. */ -public class ADSearchHandler { - private final Logger logger = LogManager.getLogger(ADSearchHandler.class); - private final Client client; - private volatile Boolean filterEnabled; +public class ADSearchHandler extends SearchHandler { public ADSearchHandler(Settings settings, ClusterService clusterService, Client client) { - this.client = client; - filterEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); - } - - /** - * Validate user role, add backend role filter if filter enabled - * and execute search. - * - * @param request search request - * @param actionListener action listener - */ - public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) { - User user = getUserContext(client); - ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, FAIL_TO_SEARCH); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - validateRole(request, user, listener); - } catch (Exception e) { - logger.error(e); - listener.onFailure(e); - } + super(settings, clusterService, client, AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES); } - - private void validateRole(SearchRequest request, User user, ActionListener<SearchResponse> listener) { - if (user == null || !filterEnabled || isAdmin(user)) { - // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin - // Case 2: If Security is enabled and filter is disabled, proceed with search as - // user is already authenticated to hit this API. - // case 3: user is admin which means we don't have to check backend role filtering - client.search(request, listener); - } else { - // Security is enabled, filter is enabled and user isn't admin - try { - addUserBackendRolesFilter(user, request.source()); - logger.debug("Filtering result by " + user.getBackendRoles()); - client.search(request, listener); - } catch (Exception e) { - listener.onFailure(e); - } - } - } - } diff --git a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java deleted file mode 100644 index 13f7e16e7..000000000 --- a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -package org.opensearch.ad.transport.handler; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ExceptionsHelper; -import org.opensearch.ResourceAlreadyExistsException; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.transport.ADResultBulkAction; -import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.util.IndexUtils; -import org.opensearch.client.Client; -import org.opensearch.cluster.block.ClusterBlockLevel; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.util.ClientUtil; - -/** - * EntityResultTransportAction depends on this class. I cannot use - * AnomalyIndexHandler < AnomalyResult > . All transport actions - * needs dependency injection. Guice has a hard time initializing generics class - * AnomalyIndexHandler < AnomalyResult > due to type erasure. - * To avoid that, I create a class with a built-in details so - * that Guice would be able to work out the details. - * - */ -public class MultiEntityResultHandler extends AnomalyIndexHandler { - private static final Logger LOG = LogManager.getLogger(MultiEntityResultHandler.class); - // package private for testing - static final String SUCCESS_SAVING_RESULT_MSG = "Result saved successfully."; - static final String CANNOT_SAVE_RESULT_ERR_MSG = "Cannot save results due to write block."; - - @Inject - public MultiEntityResultHandler( - Client client, - Settings settings, - ThreadPool threadPool, - ADIndexManagement anomalyDetectionIndices, - ClientUtil clientUtil, - IndexUtils indexUtils, - ClusterService clusterService - ) { - super( - client, - settings, - threadPool, - ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - clientUtil, - indexUtils, - clusterService - ); - } - - /** - * Execute the bulk request - * @param currentBulkRequest The bulk request - * @param listener callback after flushing - */ - public void flush(ADResultBulkRequest currentBulkRequest, ActionListener listener) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { - listener.onFailure(new TimeSeriesException(CANNOT_SAVE_RESULT_ERR_MSG)); - return; - } - - try { - if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { - anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - bulk(currentBulkRequest, listener); - } else { - LOG.warn("Creating result index with mappings call not acknowledged."); - listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged.")); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - bulk(currentBulkRequest, listener); - } else { - LOG.warn("Unexpected error creating result index", exception); - listener.onFailure(exception); - } - })); - } else { - bulk(currentBulkRequest, listener); - } - } catch (Exception e) { - 
LOG.warn("Error in bulking results", e); - listener.onFailure(e); - } - } - - private void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { - if (currentBulkRequest.numberOfActions() <= 0) { - listener.onFailure(new TimeSeriesException("no result to save")); - return; - } - client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { - LOG.debug(SUCCESS_SAVING_RESULT_MSG); - listener.onResponse(response); - }, exception -> { - LOG.error("Error in bulking results", exception); - listener.onFailure(exception); - })); - } -} diff --git a/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java b/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java new file mode 100644 index 000000000..16db8cb0d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Optional; + +import org.opensearch.client.Client; +import org.opensearch.commons.authuser.User; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ExecuteForecastResultResponseRecorder extends + ExecuteResultResponseRecorder { + + public ExecuteForecastResultResponseRecorder( + ForecastIndexManagement indexManagement, + ResultBulkIndexingHandler resultHandler, + ForecastTaskManager taskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + Client client, + NodeStateManager nodeStateManager, + TaskCacheManager taskCacheManager, + int rcfMinSamples + ) { + super( + indexManagement, + resultHandler, + taskManager, + nodeFilter, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + client, + nodeStateManager, + taskCacheManager, + rcfMinSamples, + ForecastIndex.RESULT, + AnalysisType.FORECAST, + ForecastProfileAction.INSTANCE + ); + } + + @Override + protected ForecastResult createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user + ) { + return new ForecastResult( + configId, + null, // no task id + new ArrayList(), + dataStartTime, + dataEndTime, + executeEndTime, + Instant.now(), + errorMessage, + Optional.empty(), // single-stream forecasters have no entity + user, + indexManagement.getSchemaVersion(resultIndex) + ); + } + + @Override + protected void updateRealtimeTask(ResultResponse response, String configId) { + if 
(taskManager.skipUpdateRealtimeTask(configId, response.getError())) { + return; + } + + delayedUpdate(response, configId); + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastEntityProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastEntityProfileRunner.java new file mode 100644 index 000000000..e1dcaebfd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastEntityProfileRunner.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.transport.ForecastEntityProfileAction; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.EntityProfileRunner; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ForecastEntityProfileRunner extends EntityProfileRunner { + + public ForecastEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ) { + super( + client, + clientUtil, + xContentRegistry, + requiredSamples, + Forecaster::parse, + ForecastNumericSetting.maxCategoricalFields(), + AnalysisType.FORECAST, + ForecastEntityProfileAction.INSTANCE, + ForecastIndex.RESULT.getIndexName(), + ForecastCommonName.FORECASTER_ID_KEY + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java b/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java new file mode 100644 index 000000000..d6128d030 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import java.time.Instant; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.JobProcessor; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultRequest; + +public class 
ForecastJobProcessor extends + JobProcessor { + + private static final Logger log = LogManager.getLogger(ForecastJobProcessor.class); + + private static ForecastJobProcessor INSTANCE; + + public static ForecastJobProcessor getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ForecastJobProcessor.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new ForecastJobProcessor(); + return INSTANCE; + } + } + + private ForecastJobProcessor() { + // Singleton class, use getJobRunnerInstance method instead of constructor + super(AnalysisType.FORECAST, TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, ForecastResultAction.INSTANCE); + } + + public void registerSettings(Settings settings) { + super.registerSettings(settings, ForecastSettings.FORECAST_MAX_RETRY_FOR_END_RUN_EXCEPTION); + } + + @Override + protected ResultRequest createResultRequest(String configId, long start, long end) { + return new ForecastResultRequest(configId, start, end); + } + + @Override + protected void validateResultIndexAndRunJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteForecastResultResponseRecorder recorder, + Config detector + ) { + ActionListener listener = ActionListener.wrap(r -> { log.debug("Result index is valid"); }, e -> { + Exception exception = new EndRunException(configId, e.getMessage(), false); + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, exception, recorder, detector); + }); + String resultIndex = jobParameter.getCustomResultIndex(); + if (resultIndex == null) { + indexManagement.validateDefaultResultIndexForBackendJob(configId, user, roles, () -> { + listener.onResponse(true); + runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector); + }, listener); + } else { + indexManagement.validateCustomIndexForBackendJob(resultIndex, configId, user, roles, () -> { + listener.onResponse(true); + runJob(jobParameter, lockService, lock, executionStartTime, executionEndTime, configId, user, roles, recorder, detector); + }, listener); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastProfileRunner.java new file mode 100644 index 000000000..10c1301fd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastProfileRunner.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskProfile; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.ForecasterProfile; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ProfileRunner; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.task.TaskCacheManager; 
+import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastProfileRunner extends + ProfileRunner { + + public ForecastProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ForecastTaskManager forecastTaskManager, + ForecastTaskProfileRunner taskProfileRunner + ) { + super( + client, + clientUtil, + xContentRegistry, + nodeFilter, + requiredSamples, + transportService, + forecastTaskManager, + AnalysisType.FORECAST, + ForecastTaskType.REALTIME_TASK_TYPES, + ForecastTaskType.RUN_ONCE_TASK_TYPES, + ForecastNumericSetting.maxCategoricalFields(), + ProfileName.FORECAST_TASK, + ForecastProfileAction.INSTANCE, + Forecaster::parse, + taskProfileRunner + ); + } + + @Override + protected ForecasterProfile.Builder createProfileBuilder() { + return new ForecasterProfile.Builder(); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java b/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java new file mode 100644 index 000000000..f7deb5578 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastTaskProfileRunner.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskProfile; +import org.opensearch.timeseries.TaskProfileRunner; + +public class ForecastTaskProfileRunner implements TaskProfileRunner<ForecastTask, ForecastTaskProfile> { + + @Override + public void getTaskProfile(ForecastTask configLevelTask, ActionListener<ForecastTaskProfile> listener) { + // return null since forecasting has no in-memory task profile the way AD does + listener.onResponse(null); + } + +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java b/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java new file mode 100644 index 000000000..74fb06d89 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.caching; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.CacheBuffer; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCacheBuffer extends + CacheBuffer { + + public ForecastCacheBuffer( + int minimumCapacity, + Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, + Duration modelTtl, + long memoryConsumptionPerEntity, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastCheckpointMaintainWorker checkpointMaintainQueue, + String configId, + long intervalSecs + ) { + super( + minimumCapacity, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + memoryConsumptionPerEntity, + checkpointWriteQueue, + checkpointMaintainQueue, + configId, + intervalSecs, + Origin.REAL_TIME_FORECASTER + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java b/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java new file mode 100644 index 000000000..f93982cc2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.caching; + +import org.opensearch.timeseries.caching.CacheProvider; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCacheProvider extends CacheProvider { + +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java b/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java new file mode 100644 index 000000000..5c527c9ba --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java @@ -0,0 +1,119 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.caching; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Optional; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.PriorityCache; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastPriorityCache extends + PriorityCache { + private ForecastCheckpointWriteWorker checkpointWriteQueue; + private ForecastCheckpointMaintainWorker checkpointMaintainQueue; + + public ForecastPriorityCache( + ForecastCheckpointDao checkpointDao, + int hcDedicatedCacheSize, + Setting checkpointTtl, + int maxInactiveStates, + MemoryTracker memoryTracker, + int numberOfTrees, + Clock clock, + ClusterService clusterService, + Duration modelTtl, + ThreadPool threadPool, + String threadPoolName, + int maintenanceFreqConstant, + Settings settings, + Setting checkpointSavingFreq, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastCheckpointMaintainWorker checkpointMaintainQueue + ) { + super( + checkpointDao, + hcDedicatedCacheSize, + checkpointTtl, + maxInactiveStates, + memoryTracker, + numberOfTrees, + clock, + clusterService, + modelTtl, + threadPool, + threadPoolName, + maintenanceFreqConstant, + settings, + checkpointSavingFreq, + Origin.REAL_TIME_FORECASTER, + FORECAST_DEDICATED_CACHE_SIZE, + FORECAST_MODEL_MAX_SIZE_PERCENTAGE + ); + + this.checkpointWriteQueue = checkpointWriteQueue; + this.checkpointMaintainQueue = checkpointMaintainQueue; + } + + @Override + protected ForecastCacheBuffer createEmptyCacheBuffer(Config config, long requiredMemory) { + return new ForecastCacheBuffer( + config.isHighCardinality() ? 
hcDedicatedCacheSize : 1, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + requiredMemory, + checkpointWriteQueue, + checkpointMaintainQueue, + config.getId(), + config.getIntervalInSeconds() + ); + } + + @Override + protected ModelState createEmptyModelState(String modelId, String forecasterId) { + return new ModelState<>( + null, + modelId, + forecasterId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + new Sample(), + Optional.empty(), + new ArrayDeque<>() + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/cluster/diskcleanup/ForecastCheckpointIndexRetention.java b/src/main/java/org/opensearch/forecast/cluster/diskcleanup/ForecastCheckpointIndexRetention.java new file mode 100644 index 000000000..92c1ac1e7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/cluster/diskcleanup/ForecastCheckpointIndexRetention.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.cluster.diskcleanup; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; + +public class ForecastCheckpointIndexRetention extends BaseModelCheckpointIndexRetention { + + public ForecastCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + super(defaultCheckpointTtl, clock, indexCleanup, ForecastIndex.CHECKPOINT.getIndexName()); + } + +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java b/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java index 46de0c762..deb31cad7 100644 --- a/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java @@ -35,6 +35,7 @@ public class ForecastCommonMessages { public static String FAIL_TO_FIND_FORECASTER_MSG = "Can not find forecaster with id: "; public static final String FORECASTER_ID_MISSING_MSG = "Forecaster ID is missing"; public static final String INVALID_TIMESTAMP_ERR_MSG = "timestamp is invalid"; + public static String FAIL_TO_GET_FORECASTER = "Fail to get forecaster"; // ====================================== // Security @@ -45,10 +46,17 @@ public class ForecastCommonMessages { // ====================================== // Used for custom forecast result index // ====================================== + public static String CAN_NOT_FIND_RESULT_INDEX = "Can't find result index "; public static String INVALID_RESULT_INDEX_PREFIX = "Result index must start with " + CUSTOM_RESULT_INDEX_PREFIX; // ====================================== // Task // ====================================== public static String FORECASTER_IS_RUNNING = "Forecaster is already running"; + + // ====================================== + // Job + // ====================================== + public static String FAIL_TO_START_FORECASTER = "Fail to start forecaster"; + public static String FAIL_TO_STOP_FORECASTER = "Fail to stop forecaster"; } diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java b/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java index 8edaf2d2b..f9dc48985 100644 --- a/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java +++ 
b/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java @@ -45,4 +45,9 @@ public class ForecastCommonName { // Used in stats API // ====================================== public static final String FORECASTER_ID_KEY = "forecaster_id"; + + // ====================================== + // Historical forecasters + // ====================================== + public static final String FORECAST_TASK = "forecast_task"; } diff --git a/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java b/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java index f27aa749e..a9e1b1abb 100644 --- a/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java +++ b/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java @@ -24,17 +24,26 @@ import java.io.IOException; import java.util.EnumMap; +import java.util.List; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.opensearch.action.admin.indices.alias.get.GetAliasesResponse; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.IndicesOptions; import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.InjectSecurity; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -42,6 +51,8 @@ import org.opensearch.forecast.model.ForecastResult; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.indices.IndexManagement; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -274,4 +285,104 @@ public void initCustomResultIndexDirectly(String resultIndex, ActionListener void validateDefaultResultIndexForBackendJob( + String configId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + if (doesAliasExist(ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS)) { + handleExistingIndex(configId, user, roles, function, listener); + } else { + initDefaultResultIndex(configId, user, roles, function, listener); + } + } + + private void handleExistingIndex( + String configId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + GetAliasesRequest getAliasRequest = new GetAliasesRequest() + .aliases(ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS) + .indicesOptions(IndicesOptions.lenientExpandOpenHidden()); + + adminClient.indices().getAliases(getAliasRequest, ActionListener.wrap(getAliasResponse -> { + String concreteIndex = getConcreteIndexFromAlias(getAliasResponse); + if (concreteIndex == null) { + listener.onFailure(new EndRunException("Result index alias mapping is empty", false)); + return; + } + + if 
(!isValidResultIndexMapping(concreteIndex)) { + listener.onFailure(new EndRunException("Result index mapping is not correct", false)); + return; + } + + executeWithSecurityContext(configId, user, roles, function, listener, concreteIndex); + + }, listener::onFailure)); + } + + private String getConcreteIndexFromAlias(GetAliasesResponse getAliasResponse) { + for (Map.Entry> entry : getAliasResponse.getAliases().entrySet()) { + if (!entry.getValue().isEmpty()) { + return entry.getKey(); + } + } + return null; + } + + private void initDefaultResultIndex( + String configId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + initDefaultResultIndexDirectly(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + executeWithSecurityContext(configId, user, roles, function, listener, ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS); + } else { + String error = "Creating result index with mappings call not acknowledged"; + logger.error(error); + listener.onFailure(new TimeSeriesException(error)); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + executeWithSecurityContext(configId, user, roles, function, listener, ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS); + } else { + listener.onFailure(exception); + } + })); + } + + private void executeWithSecurityContext( + String securityLogId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener, + String indexOrAlias + ) { + try (InjectSecurity injectSecurity = new InjectSecurity(securityLogId, settings, client.threadPool().getThreadContext())) { + injectSecurity.inject(user, roles); + ActionListener wrappedListener = ActionListener.wrap(listener::onResponse, e -> { + injectSecurity.close(); + listener.onFailure(e); + }); + validateResultIndexAndExecute(indexOrAlias, () -> { + injectSecurity.close(); + function.execute(); + }, true, wrappedListener); + } catch (Exception e) { + logger.error("Failed to validate custom index for backend job " + securityLogId, e); + listener.onFailure(e); + } + } + } diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java b/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java new file mode 100644 index 000000000..aefeded2f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java @@ -0,0 +1,426 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.ml; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.ClientUtil; + +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper; +import com.amazon.randomcutforest.parkservices.state.RCFCasterState; +import com.google.gson.Gson; + +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; +import io.protostuff.Schema; + +/** + * The ForecastCheckpointDao class implements all the functionality required for fetching, updating and + * removing forecast checkpoints. + * + */ +public class ForecastCheckpointDao extends CheckpointDao { + public static final Logger logger = LogManager.getLogger(ForecastCheckpointDao.class); + static final String LAST_PROCESSED_SAMPLE_FIELD = "last_processed_sample"; + + static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of forecaster"; + + RCFCasterMapper mapper; + private Schema rcfCasterSchema; + + public ForecastCheckpointDao( + Client client, + ClientUtil clientUtil, + Gson gson, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + ForecastIndexManagement indexUtil, + RCFCasterMapper mapper, + Schema rcfCasterSchema, + Clock clock + ) { + super( + client, + clientUtil, + ForecastIndex.CHECKPOINT.getIndexName(), + gson, + maxCheckpointBytes, + serializeRCFBufferPool, + serializeRCFBufferSize, + indexUtil, + clock + ); + this.mapper = mapper; + this.rcfCasterSchema = rcfCasterSchema; + } + + /** + * Puts a RCFCaster model checkpoint in the storage. Used in single-stream forecasting. 
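Single-stream models are serialized + * with protostuff and stored Base64-encoded in one checkpoint document; HC model states are indexed through + * toIndexSource below.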
+ * + * @param modelId id of the model + * @param caster the RCFCaster model + * @param listener onResponse is called with null when the operation is completed + */ + public void putCasterCheckpoint(String modelId, RCFCaster caster, ActionListener<Void> listener) { + Map<String, Object> source = new HashMap<>(); + Optional<String> modelCheckpoint = toCheckpoint(Optional.of(caster)); + if (modelCheckpoint.isPresent()) { + source.put(CommonName.FIELD_MODEL, modelCheckpoint.get()); + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT)); + putModelCheckpoint(modelId, source, listener); + } else { + listener.onFailure(new RuntimeException("Fail to create checkpoint to save")); + } + } + + private Optional<String> toCheckpoint(Optional<RCFCaster> caster) { + if (caster.isEmpty()) { + return Optional.empty(); + } + Optional<String> checkpoint = Optional.empty(); + Map.Entry<LinkedBuffer, Boolean> result = checkoutOrNewBuffer(); + LinkedBuffer buffer = result.getKey(); + boolean needCheckin = result.getValue(); + try { + checkpoint = toCheckpoint(caster, buffer); + } catch (Exception e) { + logger.error("Failed to serialize model", e); + if (needCheckin) { + try { + serializeRCFBufferPool.invalidateObject(buffer); + needCheckin = false; + } catch (Exception x) { + logger.warn("Failed to invalidate buffer", x); + } + try { + checkpoint = toCheckpoint(caster, LinkedBuffer.allocate(serializeRCFBufferSize)); + } catch (Exception ex) { + logger.warn("Failed to generate checkpoint", ex); + } + } + } finally { + if (needCheckin) { + try { + serializeRCFBufferPool.returnObject(buffer); + } catch (Exception e) { + logger.warn("Failed to return buffer to pool", e); + } + } + } + return checkpoint; + } + + private Optional<String> toCheckpoint(Optional<RCFCaster> caster, LinkedBuffer buffer) { + if (caster.isEmpty()) { + return Optional.empty(); + } + try { + byte[] bytes = AccessController.doPrivileged((PrivilegedAction<byte[]>) () -> { + RCFCasterState casterState = mapper.toState(caster.get()); + return ProtostuffIOUtil.toByteArray(casterState, rcfCasterSchema, buffer); + }); + return Optional.ofNullable(Base64.getEncoder().encodeToString(bytes)); + } finally { + buffer.clear(); + } + } + + /** + * Prepare for index request using the contents of the given model state. Used in HC forecasting. + * @param modelState an entity model state + * @return serialized JSON map or empty map if the state is too bloated + * @throws IOException when serialization fails + */ + @Override + public Map<String, Object> toIndexSource(ModelState<RCFCaster> modelState) throws IOException { + Map<String, Object> source = new HashMap<>(); + Optional<RCFCaster> model = modelState.getModel(); + + Optional<String> serializedModel = toCheckpoint(model); + if (serializedModel.isPresent() && serializedModel.get().length() <= maxCheckpointBytes) { + // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value + source.put(CommonName.FIELD_MODEL, serializedModel.get()); + } else { + logger + .warn( + new ParameterizedMessage( + "[{}]'s model is empty or too large: [{}] bytes", + modelState.getModelId(), + serializedModel.isPresent() ? 
serializedModel.get().length() : 0 + ) + ); + } + if (modelState.getSamples() != null && !(modelState.getSamples().isEmpty())) { + source.put(CommonName.ENTITY_SAMPLE_QUEUE, toCheckpoint(modelState.getSamples()).get()); + } + // if there are no samples and no model, no need to index as other information are meta data + if (!source.containsKey(CommonName.ENTITY_SAMPLE_QUEUE) && !source.containsKey(CommonName.FIELD_MODEL)) { + return source; + } + + source.put(ForecastCommonName.FORECASTER_ID_KEY, modelState.getConfigId()); + if (modelState.getLastProcessedSample() != null) { + source.put(LAST_PROCESSED_SAMPLE_FIELD, modelState.getLastProcessedSample()); + } + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT)); + + Optional entity = modelState.getEntity(); + if (entity.isPresent()) { + source.put(CommonName.ENTITY_KEY, entity.get()); + } + + return source; + } + + private void deserializeRCFCasterModel(GetResponse response, String rcfModelId, ActionListener> listener) { + Object model = null; + if (response.isExists()) { + try { + model = response.getSource().get(CommonName.FIELD_MODEL); + listener.onResponse(Optional.ofNullable(toRCFCaster((String) model))); + + } catch (Exception e) { + logger.error(new ParameterizedMessage("Unexpected error when deserializing [{}]", rcfModelId), e); + listener.onResponse(Optional.empty()); + } + } else { + listener.onResponse(Optional.empty()); + } + } + + RCFCaster toRCFCaster(String checkpoint) { + RCFCaster rcfCaster = null; + if (checkpoint != null && checkpoint.length() > 0) { + try { + byte[] bytes = Base64.getDecoder().decode(checkpoint); + RCFCasterState state = rcfCasterSchema.newMessage(); + AccessController.doPrivileged((PrivilegedAction) () -> { + ProtostuffIOUtil.mergeFrom(bytes, state, rcfCasterSchema); + return null; + }); + rcfCaster = mapper.toModel(state); + } catch (RuntimeException e) { + logger.error("Failed to deserialize RCFCaster model", e); + } + } + return rcfCaster; + } + + /** + * Returns to listener the checkpoint for the RCFCaster model. Used in single-stream forecasting. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getCasterModel(String modelId, ActionListener> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener.wrap(response -> deserializeRCFCasterModel(response, modelId, listener), exception -> { + // expected exception, don't print stack trace + if (exception instanceof IndexNotFoundException) { + listener.onResponse(Optional.empty()); + } else { + listener.onFailure(exception); + } + }) + ); + } + + /** + * Load json checkpoint into models. Used in HC forecasting. 
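A checkpoint that cannot be + * deserialized yields a null model state so the caller can redo training.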
+ * + * @param checkpoint json checkpoint contents + * @param modelId Model Id + * @return the model state rebuilt from the checkpoint, or null if + * the raw checkpoint is corrupted or too large + */ + @Override + protected ModelState<RCFCaster> fromEntityModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) { + try { + return AccessController.doPrivileged((PrivilegedAction<ModelState<RCFCaster>>) () -> { + + RCFCaster rcfCaster = loadRCFCaster(checkpoint, modelId); + + Entity entity = null; + Object serializedEntity = checkpoint.get(CommonName.ENTITY_KEY); + if (serializedEntity != null) { + try { + entity = Entity.fromJsonArray(serializedEntity); + } catch (Exception e) { + logger.error(new ParameterizedMessage("fail to parse entity [{}]", serializedEntity), e); + } + } + + ModelState<RCFCaster> modelState = new ModelState<>( + rcfCaster, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + loadLastProcessedSample(checkpoint, modelId), + Optional.ofNullable(entity), + loadSampleQueue(checkpoint, modelId) + ); + + modelState.setLastCheckpointTime(loadTimestamp(checkpoint, modelId)); + + return modelState; + }); + } catch (Exception e) { + logger.warn("Exception while deserializing checkpoint " + modelId, e); + // checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training. + return null; + } + } + + /** + * Delete checkpoints associated with a forecaster. Used in HC forecasting. + * @param forecasterId Forecaster Id + */ + public void deleteModelCheckpointByForecasterId(String forecasterId) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick, they are not rolled back. + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(indexName) + .setQuery(new MatchQueryBuilder(ForecastCommonName.FORECASTER_ID_KEY, forecasterId)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + logger.info("Delete checkpoints of forecaster {}", forecasterId); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, forecasterId); + } + // The response can report 0 deleted docs because: + // 1) we cannot find matching docs + // 2) bad stats from OpenSearch. In this case, docs are deleted, but + // OpenSearch says deleted is 0. + logger.info("{} " + CheckpointDao.DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(CheckpointDao.INDEX_DELETED_LOG_MSG + " {}", forecasterId); + } else { + // Any leftover checkpoints are eventually removed by the daily cron. 
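+ // (ForecastCheckpointIndexRetention, added earlier in this change, purges checkpoint docs past their TTL) 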
+ logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, exception); + } + })); + } + + @Override + protected DeleteByQueryRequest createDeleteCheckpointRequest(String configId) { + return new DeleteByQueryRequest(indexName) + .setQuery(new MatchQueryBuilder(ForecastCommonName.FORECASTER_ID_KEY, configId)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + } + + @Override + protected ModelState<RCFCaster> fromSingleStreamModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) { + + return AccessController.doPrivileged((PrivilegedAction<ModelState<RCFCaster>>) () -> { + + RCFCaster rcfCaster = loadRCFCaster(checkpoint, modelId); + + ModelState<RCFCaster> modelState = new ModelState<>( + rcfCaster, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + loadLastProcessedSample(checkpoint, modelId), + Optional.empty(), + loadSampleQueue(checkpoint, modelId) + ); + + modelState.setLastCheckpointTime(loadTimestamp(checkpoint, modelId)); + + return modelState; + }); + } + + private RCFCaster loadRCFCaster(Map<String, Object> checkpoint, String modelId) { + String model = (String) checkpoint.get(CommonName.FIELD_MODEL); + if (model == null || model.length() > maxCheckpointBytes) { + // guard the log call so a missing model does not trigger an NPE on model.length() + logger.warn(new ParameterizedMessage("[{}]'s model is missing or too large: [{}] bytes", modelId, model == null ? 0 : model.length())); + return null; + } + return toRCFCaster(model); + } + + private Sample loadLastProcessedSample(Map<String, Object> checkpoint, String modelId) { + Map<String, Object> lastProcessedSample = (Map<String, Object>) checkpoint.get(LAST_PROCESSED_SAMPLE_FIELD); + if (lastProcessedSample == null || lastProcessedSample.size() == 0) { + return null; + } + + try { + return Sample.extractSample(lastProcessedSample); + } catch (Exception e) { + logger.warn("Exception while deserializing last processed sample for " + modelId, e); + // checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training. + return null; + } + } + + private Instant loadTimestamp(Map<String, Object> checkpoint, String modelId) { + String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP)); + return Instant.parse(lastCheckpointTimeString); + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java b/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java new file mode 100644 index 000000000..71c413037 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java @@ -0,0 +1,166 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.ml; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.config.TransformMethod; +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.calibration.Calibration; + +public class ForecastColdStart extends + ModelColdStart { + + private static final Logger logger = LogManager.getLogger(ForecastColdStart.class); + + public ForecastColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + double rcfTimeDecay, + int numMinSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ForecastCheckpointWriteWorker checkpointWriteWorker, + int coolDownMinutes, + long rcfSeed, + int defaultTrainSamples, + int maxRoundofColdStart + ) { + // 1 means we sample all real data if possible + super( + modelTtl, + coolDownMinutes, + clock, + threadPool, + numMinSamples, + checkpointWriteWorker, + rcfSeed, + numberOfTrees, + rcfSampleSize, + thresholdMinPvalue, + rcfTimeDecay, + nodeStateManager, + 1, + defaultTrainSamples, + searchFeatureDao, + featureManager, + maxRoundofColdStart, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + AnalysisType.FORECAST + ); + } + + @Override + protected List> trainModelFromDataSegments( + Pair>, Sample> pointSamplePair, + Optional entity, + ModelState modelState, + Config config, + String taskId + ) { + List> dataPoints = pointSamplePair.getKey(); + if (dataPoints == null || dataPoints.size() == 0) { + logger.info("Return early since data points must not be empty."); + return null; + } + + double[] firstPoint = dataPoints.get(0).getValue(); + if (firstPoint == null || firstPoint.length == 0) { + logger.info("Return early since data points must not be empty."); + return null; + } + + int shingleSize = config.getShingleSize(); + int forecastHorizon = ((Forecaster) config).getHorizon(); + int dimensions = firstPoint.length * shingleSize; + + RCFCaster.Builder casterBuilder = RCFCaster + .builder() + .dimensions(dimensions) + .numberOfTrees(numberOfTrees) + .shingleSize(shingleSize) + .sampleSize(rcfSampleSize) + .internalShinglingEnabled(true) + .precision(Precision.FLOAT_32) + .anomalyRate(1 - this.thresholdMinPvalue) + 
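// the caster emits forecasts only after numMinSamples observations; anomalyRate is the + // expected anomaly fraction fed to the internal thresholder (e.g., with an assumed + // thresholdMinPvalue of 0.995 this evaluates to 0.005) + 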
.outputAfter(numMinSamples) + .calibration(Calibration.MINIMAL) + .timeDecay(rcfTimeDecay) + .parallelExecutionEnabled(false) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + // the following affects the moving average in many of the transformations + // the 0.02 corresponds to a half life of 1/0.02 = 50 observations + // this is different from the timeDecay() of RCF; however it is a similar + // concept + .transformDecay(config.getTransformDecay()) + .forecastHorizon(forecastHorizon) + .initialAcceptFraction(initialAcceptFraction) + // normalize transform is required to deal with trend change in forecasting + .transformMethod(TransformMethod.NORMALIZE); + + if (rcfSeed > 0) { + casterBuilder.randomSeed(rcfSeed); + } + + RCFCaster caster = casterBuilder.build(); + + for (int i = 0; i < dataPoints.size(); i++) { + double[] dataValue = dataPoints.get(i).getValue(); + caster.process(dataValue, 0); + } + + modelState.setModel(caster); + modelState.setLastUsedTime(clock.instant()); + modelState.setLastProcessedSample(pointSamplePair.getValue()); + // save to checkpoint for real time cold start that has no taskId + if (null == taskId) { + checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM); + } + return dataPoints; + } + + @Override + protected boolean isInterpolationInColdStartEnabled() { + return false; + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java new file mode 100644 index 000000000..13f7ceb56 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.ml; + +import java.time.Clock; +import java.util.Locale; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.ModelManager; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ForecastDescriptor; +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastModelManager extends + ModelManager { + + public ForecastModelManager( + ForecastCheckpointDao checkpointDao, + Clock clock, + int rcfNumTrees, + int rcfNumSamplesInTree, + double rcfTimeDecay, + int rcfNumMinSamples, + ForecastColdStart entityColdStarter, + MemoryTracker memoryTracker, + FeatureManager featureManager + ) { + super( + rcfNumTrees, + rcfNumSamplesInTree, + rcfTimeDecay, + rcfNumMinSamples, + entityColdStarter, + memoryTracker, + clock, + featureManager, + checkpointDao + ); + } + + @Override + protected RCFCasterResult createEmptyResult() { + return new RCFCasterResult(null, 0, 0, 0); + } + + @Override + protected RCFCasterResult toResult(RandomCutForest forecast, RCFDescriptor castDescriptor) { + if (castDescriptor instanceof ForecastDescriptor) { + ForecastDescriptor forecastDescriptor = (ForecastDescriptor) castDescriptor; + // Use forecastDescriptor in the rest of your method + return new RCFCasterResult( + forecastDescriptor.getTimedForecast().rangeVector, + forecastDescriptor.getDataConfidence(), + forecast.getTotalUpdates(), + forecastDescriptor.getRCFScore() + ); + } else { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported type of AnomalyDescriptor : %s", castDescriptor)); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java b/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java new file mode 100644 index 000000000..3584c7203 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.ml; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; + +import com.amazon.randomcutforest.returntypes.RangeVector; + +public class RCFCasterResult extends IntermediateResult { + private final RangeVector forecast; + private final double dataQuality; + + public RCFCasterResult(RangeVector forecast, double dataQuality, long totalUpdates, double rcfScore) { + super(totalUpdates, rcfScore); + this.forecast = forecast; + this.dataQuality = dataQuality; + } + + public RangeVector getForecast() { + return forecast; + } + + public double getDataQuality() { + return dataQuality; + } + + @Override + public List toIndexableResults( + Config forecaster, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + List featureData, + Optional entity, + Integer schemaVersion, + String modelId, + String taskId, + String error + ) { + if (forecast.values == null || forecast.values.length == 0) { + return Collections.emptyList(); + } + return ForecastResult + .fromRawRCFCasterResult( + forecaster.getId(), + forecaster.getIntervalInMilliseconds(), + dataQuality, + featureData, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + entity, + forecaster.getUser(), + schemaVersion, + modelId, + forecast.values, + forecast.upper, + forecast.lower, + taskId + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/model/FilterBy.java b/src/main/java/org/opensearch/forecast/model/FilterBy.java new file mode 100644 index 000000000..d2be61012 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/FilterBy.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +public enum FilterBy { + BUILD_IN_QUERY, + CUSTOM_QUERY +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java b/src/main/java/org/opensearch/forecast/model/ForecastResult.java index 1ce75ff63..34ff4da66 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastResult.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java @@ -68,8 +68,9 @@ public class ForecastResult extends IndexableResult { private final Instant forecastDataEndTime; private final Integer horizonIndex; protected final Double dataQuality; + private final String entityId; - // used when indexing exception or error or an empty result + // used when indexing exception or error or a feature only result public ForecastResult( String forecasterId, String taskId, @@ -81,8 +82,7 @@ public ForecastResult( String error, Optional entity, User user, - Integer schemaVersion, - String modelId + Integer schemaVersion ) { this( forecasterId, @@ -97,7 +97,6 @@ public ForecastResult( entity, user, schemaVersion, - modelId, null, null, null, @@ -121,7 +120,6 @@ public ForecastResult( Optional entity, User user, Integer schemaVersion, - String modelId, String featureId, Float forecastValue, Float lowerBound, @@ -141,7 +139,6 @@ public ForecastResult( entity, user, schemaVersion, - modelId, taskId ); this.featureId = featureId; @@ -149,10 +146,11 @@ public ForecastResult( this.forecastValue = 
forecastValue; this.lowerBound = lowerBound; this.upperBound = upperBound; - this.confidenceIntervalWidth = lowerBound != null && upperBound != null ? Math.abs(upperBound - lowerBound) : Float.NaN; + this.confidenceIntervalWidth = safeAbsoluteDifference(lowerBound, upperBound); this.forecastDataStartTime = forecastDataStartTime; this.forecastDataEndTime = forecastDataEndTime; this.horizonIndex = horizonIndex; + this.entityId = getEntityId(entity, configId); } public static List fromRawRCFCasterResult( @@ -175,9 +173,13 @@ public static List fromRawRCFCasterResult( String taskId ) { int inputLength = featureData.size(); - int numberOfForecasts = forecastsValues.length / inputLength; + int numberOfForecasts = 0; + if (forecastsValues != null) { + numberOfForecasts = forecastsValues.length / inputLength; + } - List convertedForecastValues = new ArrayList<>(numberOfForecasts); + // +1 for actual value + List convertedForecastValues = new ArrayList<>(numberOfForecasts + 1); // store feature data and forecast value separately for easy query on feature data // we can join them using forecasterId, entityId, and executionStartTime/executionEndTime @@ -196,7 +198,6 @@ public static List fromRawRCFCasterResult( entity, user, schemaVersion, - modelId, null, null, null, @@ -219,22 +220,22 @@ public static List fromRawRCFCasterResult( taskId, dataQuality, null, - null, - null, + dataStartTime, + dataEndTime, executionStartTime, executionEndTime, error, entity, user, schemaVersion, - modelId, featureData.get(j).getFeatureId(), forecastsValues[k], forecastsLowers[k], forecastsUppers[k], forecastDataStartTime, forecastDataEndTime, - i + // horizon starts from 1 + i + 1 ) ); } @@ -255,6 +256,7 @@ public ForecastResult(StreamInput input) throws IOException { this.forecastDataStartTime = input.readOptionalInstant(); this.forecastDataEndTime = input.readOptionalInstant(); this.horizonIndex = input.readOptionalInt(); + this.entityId = input.readOptionalString(); } @Override @@ -286,14 +288,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(CommonName.ERROR_FIELD, error); } if (optionalEntity.isPresent()) { - xContentBuilder.field(CommonName.ENTITY_FIELD, optionalEntity.get()); + xContentBuilder.field(CommonName.ENTITY_KEY, optionalEntity.get()); } if (user != null) { xContentBuilder.field(CommonName.USER_FIELD, user); } - if (modelId != null) { - xContentBuilder.field(CommonName.MODEL_ID_FIELD, modelId); - } if (dataQuality != null && !dataQuality.isNaN()) { xContentBuilder.field(CommonName.DATA_QUALITY_FIELD, dataQuality); } @@ -312,13 +311,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (upperBound != null) { xContentBuilder.field(UPPER_BOUND_FIELD, upperBound); } + if (confidenceIntervalWidth != null) { + xContentBuilder.field(INTERVAL_WIDTH_FIELD, confidenceIntervalWidth); + } if (forecastDataStartTime != null) { xContentBuilder.field(FORECAST_DATA_START_TIME_FIELD, forecastDataStartTime.toEpochMilli()); } if (forecastDataEndTime != null) { xContentBuilder.field(FORECAST_DATA_END_TIME_FIELD, forecastDataEndTime.toEpochMilli()); } - if (horizonIndex != null) { + // the document with the actual value should not contain horizonIndex + // its horizonIndex is -1. 
Actual forecast value starts from horizon index 1 + if (horizonIndex != null && horizonIndex > 0) { xContentBuilder.field(HORIZON_INDEX_FIELD, horizonIndex); } if (featureId != null) { @@ -340,7 +344,6 @@ public static ForecastResult parse(XContentParser parser) throws IOException { Entity entity = null; User user = null; Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; - String modelId = null; String taskId = null; String featureId = null; @@ -385,7 +388,7 @@ public static ForecastResult parse(XContentParser parser) throws IOException { case CommonName.ERROR_FIELD: error = parser.text(); break; - case CommonName.ENTITY_FIELD: + case CommonName.ENTITY_KEY: entity = Entity.parse(parser); break; case CommonName.USER_FIELD: @@ -394,9 +397,6 @@ public static ForecastResult parse(XContentParser parser) throws IOException { case CommonName.SCHEMA_VERSION_FIELD: schemaVersion = parser.intValue(); break; - case CommonName.MODEL_ID_FIELD: - modelId = parser.text(); - break; case FEATURE_ID_FIELD: featureId = parser.text(); break; @@ -440,7 +440,6 @@ public static ForecastResult parse(XContentParser parser) throws IOException { Optional.ofNullable(entity), user, schemaVersion, - modelId, featureId, forecastValue, lowerBound, @@ -469,7 +468,8 @@ public boolean equals(Object o) { && Objects.equal(confidenceIntervalWidth, that.confidenceIntervalWidth) && Objects.equal(forecastDataStartTime, that.forecastDataStartTime) && Objects.equal(forecastDataEndTime, that.forecastDataEndTime) - && Objects.equal(horizonIndex, that.horizonIndex); + && Objects.equal(horizonIndex, that.horizonIndex) + && Objects.equal(entityId, that.entityId); } @Generated @@ -487,7 +487,8 @@ public int hashCode() { confidenceIntervalWidth, forecastDataStartTime, forecastDataEndTime, - horizonIndex + horizonIndex, + entityId ); return result; } @@ -507,6 +508,7 @@ public String toString() { .append("forecastDataStartTime", forecastDataStartTime) .append("forecastDataEndTime", forecastDataEndTime) .append("horizonIndex", horizonIndex) + .append("entityId", entityId) .toString(); } @@ -523,6 +525,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInstant(forecastDataStartTime); out.writeOptionalInstant(forecastDataEndTime); out.writeOptionalInt(horizonIndex); + out.writeOptionalString(entityId); } public static ForecastResult getDummyResult() { @@ -537,8 +540,7 @@ public static ForecastResult getDummyResult() { null, Optional.empty(), null, - CommonValue.NO_SCHEMA_VERSION, - null + CommonValue.NO_SCHEMA_VERSION ); } @@ -587,4 +589,43 @@ public Instant getForecastDataEndTime() { public Integer getHorizonIndex() { return horizonIndex; } + + public String getEntityId() { + return entityId; + } + + /** + * Safely calculates the absolute difference between two Float values. + * + *
<p>
This method handles potential null values, as well as special Float values + * like NaN, Infinity, and -Infinity. If either of the input values is null, + * the method returns null. If the difference results in NaN or Infinity values, + * the method returns Float.MAX_VALUE. + * + *
<p>
Note: Float.MIN_VALUE is considered the smallest positive nonzero value + * of type float. The smallest negative value is -Float.MAX_VALUE. + * + * @param a The first Float value. + * @param b The second Float value. + * @return The absolute difference between the two values, or null if any input is null. + * If the result is NaN or Infinity, returns Float.MAX_VALUE. + */ + public Float safeAbsoluteDifference(Float a, Float b) { + // Check for null values + if (a == null || b == null) { + return null; // or throw an exception, or handle as per your requirements + } + + // Calculate the difference + float diff = a - b; + + // Check for special values + if (Float.isNaN(diff) || Float.isInfinite(diff)) { + return Float.MAX_VALUE; // or handle in any other way you see fit + } + + // Return the absolute difference + return Math.abs(diff); + } + } diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResultBucket.java b/src/main/java/org/opensearch/forecast/model/ForecastResultBucket.java new file mode 100644 index 000000000..aa3dc21db --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastResultBucket.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; +import java.util.Map; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +public class ForecastResultBucket implements ToXContentObject, Writeable { + public static final String BUCKETS_FIELD = "buckets"; + public static final String KEY_FIELD = "key"; + public static final String DOC_COUNT_FIELD = "doc_count"; + public static final String BUCKET_INDEX_FIELD = "bucket_index"; + + // e.g., "ip": "1.2.3.4" + private final Map key; + private final int docCount; + private final Map aggregations; + private final int bucketIndex; + + public ForecastResultBucket(Map key, int docCount, Map aggregations, int bucketIndex) { + this.key = key; + this.docCount = docCount; + this.aggregations = aggregations; + this.bucketIndex = bucketIndex; + } + + public ForecastResultBucket(StreamInput input) throws IOException { + this.key = input.readMap(); + this.docCount = input.readInt(); + this.aggregations = input.readMap(StreamInput::readString, StreamInput::readDouble); + this.bucketIndex = input.readInt(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(KEY_FIELD, key) + .field(DOC_COUNT_FIELD, docCount) + .field(BUCKET_INDEX_FIELD, bucketIndex); + + for (Map.Entry entry : aggregations.entrySet()) { + xContentBuilder.field(entry.getKey(), entry.getValue()); + } + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(key); + out.writeInt(docCount); + out.writeMap(aggregations, StreamOutput::writeString, StreamOutput::writeDouble); + out.writeInt(bucketIndex); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + 
ForecastResultBucket that = (ForecastResultBucket) o; + return Objects.equal(key, that.getKey()) + && Objects.equal(docCount, that.getDocCount()) + && Objects.equal(aggregations, that.getAggregations()) + && Objects.equal(bucketIndex, that.getBucketIndex()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(key, docCount, aggregations, bucketIndex); + } + + @Generated + @Override + public String toString() { + return new ToStringBuilder(this) + .append("key", key) + .append("docCount", docCount) + .append("aggregations", aggregations) + .append("bucketIndex", bucketIndex) + .toString(); + } + + public Map<String, Object> getKey() { + return key; + } + + public int getDocCount() { + return docCount; + } + + public Map<String, Double> getAggregations() { + return aggregations; + } + + public int getBucketIndex() { + return bucketIndex; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTask.java b/src/main/java/org/opensearch/forecast/model/ForecastTask.java index 4d7e889d7..27131131a 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastTask.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastTask.java @@ -128,8 +128,9 @@ public static Builder builder() { } @Override - public boolean isEntityTask() { - return ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY.name().equals(taskType); + public boolean isHistoricalEntityTask() { + // we have no backtesting + return false; } public static class Builder extends TimeSeriesTask.Builder { @@ -324,7 +325,8 @@ public static ForecastTask parse(XContentParser parser, String taskId) throws IO forecaster.getUser(), forecaster.getCustomResultIndex(), forecaster.getHorizon(), - forecaster.getImputationOption() + forecaster.getImputationOption(), + forecaster.getTransformDecay() ); return new Builder() .taskId(parsedTaskId) diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTaskProfile.java b/src/main/java/org/opensearch/forecast/model/ForecastTaskProfile.java new file mode 100644 index 000000000..fbde0b7d5 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastTaskProfile.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.TaskProfile; + +public class ForecastTaskProfile extends TaskProfile<ForecastTask> { + + public ForecastTaskProfile( + ForecastTask forecastTask, + Integer shingleSize, + Long rcfTotalUpdates, + Long modelSizeInBytes, + String nodeId, + String taskId, + String taskType + ) { + super(forecastTask, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, taskType); + } + + 
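// Deserialization counterpart of writeTo below: a leading boolean flags whether a + // task is present, then the optional fields follow in the same order they are written. + 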
public ForecastTaskProfile(StreamInput input) throws IOException { + if (input.readBoolean()) { + this.task = new ForecastTask(input); + } else { + this.task = null; + } + this.shingleSize = input.readOptionalInt(); + this.rcfTotalUpdates = input.readOptionalLong(); + this.modelSizeInBytes = input.readOptionalLong(); + this.nodeId = input.readOptionalString(); + this.taskId = input.readOptionalString(); + this.taskType = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (task != null) { + out.writeBoolean(true); + task.writeTo(out); + } else { + out.writeBoolean(false); + } + + out.writeOptionalInt(shingleSize); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalLong(modelSizeInBytes); + out.writeOptionalString(nodeId); + out.writeOptionalString(taskId); + out.writeOptionalString(taskType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + super.toXContent(xContentBuilder); + return xContentBuilder.endObject(); + } + + public static ForecastTaskProfile parse(XContentParser parser) throws IOException { + ForecastTask forecastTask = null; + Integer shingleSize = null; + Long rcfTotalUpdates = null; + Long modelSizeInBytes = null; + String nodeId = null; + String taskId = null; + String taskType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ForecastCommonName.FORECAST_TASK: + forecastTask = ForecastTask.parse(parser); + break; + case SHINGLE_SIZE_FIELD: + shingleSize = parser.intValue(); + break; + case RCF_TOTAL_UPDATES_FIELD: + rcfTotalUpdates = parser.longValue(); + break; + case MODEL_SIZE_IN_BYTES: + modelSizeInBytes = parser.longValue(); + break; + case NODE_ID_FIELD: + nodeId = parser.text(); + break; + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case TASK_TYPE_FIELD: + taskType = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return new ForecastTaskProfile(forecastTask, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, taskType); + } + + @Override + protected String getTaskFieldName() { + return ForecastCommonName.FORECAST_TASK; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java index 76e1aac88..16cbf902a 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java @@ -25,45 +25,33 @@ * to single-stream forecasting, and two tasks for HC, one at the forecaster level and another at the entity level. * * Real-time forecasting: - * - FORECAST_REALTIME_SINGLE_STREAM: Represents a task type for single-stream forecasting. Ideal for scenarios where a single + * - REALTIME_FORECAST_SINGLE_STREAM: Represents a task type for single-stream forecasting. Ideal for scenarios where a single * time series is processed in real-time. - * - FORECAST_REALTIME_HC_FORECASTER: Represents a task type for high cardinality (HC) forecasting. Used when dealing with a + * - REALTIME_FORECAST_HC_FORECASTER: Represents a task type for high cardinality (HC) forecasting. Used when dealing with a * large number of distinct entities in real-time. 
* - * Historical forecasting: - * - FORECAST_HISTORICAL_SINGLE_STREAM: Represents a forecaster-level task for single-stream historical forecasting. - * Suitable for analyzing a single time series in a sequential manner. - * - FORECAST_HISTORICAL_HC_FORECASTER: A forecaster-level task to track overall state, initialization progress, errors, etc., - * for HC forecasting. Central to managing multiple historical time series with high cardinality. - * - FORECAST_HISTORICAL_HC_ENTITY: An entity-level task to track the state, initialization progress, errors, etc., of a - * specific entity within HC historical forecasting. Allows for fine-grained information recording at the entity level. + * Run once forecasting: + * - RUN_ONCE_FORECAST_SINGLE_STREAM: forecast once in single-stream scenario. + * - RUN_ONCE_FORECAST_HC_FORECASTER: forecast once in HC scenario. + * + * enum names need to start with REALTIME or HISTORICAL we use prefix in TaskManager to check if a task is of certain type (e.g., historical) * */ public enum ForecastTaskType implements TaskType { - FORECAST_REALTIME_SINGLE_STREAM, - FORECAST_REALTIME_HC_FORECASTER, - FORECAST_HISTORICAL_SINGLE_STREAM, - // forecaster level task to track overall state, init progress, error etc. for HC forecaster - FORECAST_HISTORICAL_HC_FORECASTER, - // entity level task to track just one specific entity's state, init progress, error etc. - FORECAST_HISTORICAL_HC_ENTITY; + REALTIME_FORECAST_SINGLE_STREAM, + REALTIME_FORECAST_HC_FORECASTER, + RUN_ONCE_FORECAST_SINGLE_STREAM, + RUN_ONCE_FORECAST_HC_FORECASTER; - public static List HISTORICAL_FORECASTER_TASK_TYPES = ImmutableList - .of(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM); - public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList - .of( - ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, - ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY - ); public static List REALTIME_TASK_TYPES = ImmutableList - .of(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER); + .of(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM, ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER); public static List ALL_FORECAST_TASK_TYPES = ImmutableList .of( - ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, - ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, - ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, - ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + REALTIME_FORECAST_SINGLE_STREAM, + REALTIME_FORECAST_HC_FORECASTER, + RUN_ONCE_FORECAST_SINGLE_STREAM, + RUN_ONCE_FORECAST_HC_FORECASTER ); + public static List RUN_ONCE_TASK_TYPES = ImmutableList + .of(ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM, ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER); } diff --git a/src/main/java/org/opensearch/forecast/model/Forecaster.java b/src/main/java/org/opensearch/forecast/model/Forecaster.java index c572c28db..b4b116d9e 100644 --- a/src/main/java/org/opensearch/forecast/model/Forecaster.java +++ b/src/main/java/org/opensearch/forecast/model/Forecaster.java @@ -28,6 +28,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.forecast.constant.ForecastCommonMessages; import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.settings.ForecastSettings; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import 
org.opensearch.timeseries.common.exception.ValidationException; @@ -85,7 +86,8 @@ public Forecaster( User user, String resultIndex, Integer horizon, - ImputationOption imputationOption + ImputationOption imputationOption, + Double transformDecay ) { super( forecasterId, @@ -105,34 +107,53 @@ public Forecaster( user, resultIndex, forecastInterval, - imputationOption + imputationOption, + transformDecay ); checkAndThrowValidationErrors(ValidationAspect.FORECASTER); if (forecastInterval == null) { - errorMessage = ForecastCommonMessages.NULL_FORECAST_INTERVAL; - issueType = ValidationIssueType.FORECAST_INTERVAL; + throw new ValidationException( + ForecastCommonMessages.NULL_FORECAST_INTERVAL, + ValidationIssueType.FORECAST_INTERVAL, + ValidationAspect.FORECASTER + ); } else if (((IntervalTimeConfiguration) forecastInterval).getInterval() <= 0) { - errorMessage = ForecastCommonMessages.INVALID_FORECAST_INTERVAL; - issueType = ValidationIssueType.FORECAST_INTERVAL; + throw new ValidationException( + ForecastCommonMessages.INVALID_FORECAST_INTERVAL, + ValidationIssueType.FORECAST_INTERVAL, + ValidationAspect.FORECASTER + ); } int maxCategoryFields = ForecastNumericSetting.maxCategoricalFields(); if (categoryFields != null && categoryFields.size() > maxCategoryFields) { - errorMessage = CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields); - issueType = ValidationIssueType.CATEGORY; + throw new ValidationException( + CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), + ValidationIssueType.CATEGORY, + ValidationAspect.FORECASTER + ); } if (invalidHorizon(horizon)) { - errorMessage = "Horizon size must be a positive integer no larger than " - + TimeSeriesSettings.MAX_SHINGLE_SIZE * DEFAULT_HORIZON_SHINGLE_RATIO - + ". Got " - + horizon; - issueType = ValidationIssueType.SHINGLE_SIZE_FIELD; + throw new ValidationException( + "Horizon size must be a positive integer no larger than " + + TimeSeriesSettings.MAX_SHINGLE_SIZE * DEFAULT_HORIZON_SHINGLE_RATIO + + ". Got " + + horizon, + ValidationIssueType.HORIZON_SIZE, + ValidationAspect.FORECASTER + ); } - checkAndThrowValidationErrors(ValidationAspect.FORECASTER); + if (shingleSize < 4) { + throw new ValidationException( + "Shingle size must be no less than " + ForecastSettings.MINIMUM_SHINLE_SIZE + ". 
Got " + shingleSize, + ValidationIssueType.SHINGLE_SIZE_FIELD, + ValidationAspect.FORECASTER + ); + } this.horizon = horizon; } @@ -220,7 +241,8 @@ public static Forecaster parse( List categoryField = null; Integer horizon = null; - ImputationOption interpolationOption = null; + ImputationOption imputationOption = null; + Double transformDecay = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -329,7 +351,10 @@ public static Forecaster parse( horizon = parser.intValue(); break; case IMPUTATION_OPTION_FIELD: - interpolationOption = ImputationOption.parse(parser); + imputationOption = ImputationOption.parse(parser); + break; + case TRANSFORM_DECAY_FIELD: + transformDecay = parser.doubleValue(); break; default: parser.skipChildren(); @@ -355,7 +380,8 @@ public static Forecaster parse( user, resultIndex, horizon, - interpolationOption + imputationOption, + transformDecay ); return forecaster; } diff --git a/src/main/java/org/opensearch/forecast/model/ForecasterProfile.java b/src/main/java/org/opensearch/forecast/model/ForecasterProfile.java new file mode 100644 index 000000000..809409d0d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecasterProfile.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.model.ConfigProfile; + +public class ForecasterProfile extends ConfigProfile { + + public static class Builder extends ConfigProfile.Builder { + private ForecastTaskProfile forecastTaskProfile; + + public Builder() {} + + @Override + public Builder taskProfile(ForecastTaskProfile forecastTaskProfile) { + this.forecastTaskProfile = forecastTaskProfile; + return this; + } + + @Override + public ForecasterProfile build() { + ForecasterProfile profile = new ForecasterProfile(); + profile.state = state; + profile.error = error; + profile.modelProfile = modelProfile; + profile.modelCount = modelCount; + profile.shingleSize = shingleSize; + profile.coordinatingNode = coordinatingNode; + profile.totalSizeInBytes = totalSizeInBytes; + profile.initProgress = initProgress; + profile.totalEntities = totalEntities; + profile.activeEntities = activeEntities; + profile.taskProfile = forecastTaskProfile; + + return profile; + } + } + + public ForecasterProfile() {} + + public ForecasterProfile(StreamInput in) throws IOException { + super(in); + } + + @Override + protected ForecastTaskProfile createTaskProfile(StreamInput in) throws IOException { + return new ForecastTaskProfile(in); + } + + @Override + protected String getTaskFieldName() { + return ForecastCommonName.FORECAST_TASK; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/Order.java b/src/main/java/org/opensearch/forecast/model/Order.java new file mode 100644 index 000000000..6471bd75a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/Order.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +public enum 
Order { + ASC, + DESC +} diff --git a/src/main/java/org/opensearch/forecast/model/Subaggregation.java b/src/main/java/org/opensearch/forecast/model/Subaggregation.java new file mode 100644 index 000000000..376b0226b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/Subaggregation.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +public class Subaggregation implements Writeable, ToXContentObject { + private static final String AGGREGATION_QUERY = "aggregation_query"; + private static final String ORDER = "order"; + + private final AggregationBuilder aggregation; + private final Order order; + + public Subaggregation(AggregationBuilder aggregation, Order order) { + super(); + this.aggregation = aggregation; + this.order = order; + } + + public Subaggregation(StreamInput input) throws IOException { + this.aggregation = input.readNamedWriteable(AggregationBuilder.class); + this.order = input.readEnum(Order.class); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(ORDER, order.name()) + .field(AGGREGATION_QUERY) + .startObject() + .value(aggregation) + .endObject(); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(aggregation); + out.writeEnum(order); + } + + /** + * Parse raw json content into Subaggregation instance. 
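+ * <p>For illustration only (the aggregation name and field below are invented for this example, not taken from elsewhere in the PR), the JSON this method expects has the shape: + * <pre> + * { "order": "ASC", "aggregation_query": { "sum_value": { "sum": { "field": "value" } } } } + * </pre>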
+ * + * @param parser json based content parser + * @return Subaggregation instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Subaggregation parse(XContentParser parser) throws IOException { + Order order = Order.ASC; + AggregationBuilder aggregation = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + + parser.nextToken(); + switch (fieldName) { + case ORDER: + order = Order.valueOf(parser.text()); + break; + case AGGREGATION_QUERY: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + aggregation = ParseUtils.toAggregationBuilder(parser); + break; + default: + break; + } + } + return new Subaggregation(aggregation, order); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Subaggregation that = (Subaggregation) o; + return Objects.equal(order, that.getOrder()) && Objects.equal(aggregation, that.getAggregation()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(aggregation, order); + } + + public AggregationBuilder getAggregation() { + return aggregation; + } + + public Order getOrder() { + return order; + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java new file mode 100644 index 000000000..f0c617861 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java @@ -0,0 +1,86 @@ +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.Optional; +import java.util.Random; +import java.util.function.Function; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; +import org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker; + +public class ForecastCheckpointMaintainWorker extends CheckpointMaintainWorker { + public static final String WORKER_NAME = "forecast-checkpoint-maintain"; + + public ForecastCheckpointMaintainWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + 
RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Function> converter + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + targetQueue, + stateTtl, + nodeStateManager, + converter, + AnalysisType.FORECAST + ); + + this.batchSize = FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS + .get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + it -> this.expectedExecutionTimeInMilliSecsPerRequest = it + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java new file mode 100644 index 000000000..5bbcb4e1e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.stats.StatNames; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointReadWorker extends + CheckpointReadWorker { + public static final String WORKER_NAME = "forecast-checkpoint-read"; + + public ForecastCheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + 
Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastModelManager modelManager, + ForecastCheckpointDao checkpointDao, + ForecastColdStartWorker entityColdStartQueue, + NodeStateManager stateManager, + ForecastIndexManagement indexUtil, + Provider cacheProvider, + Duration stateTtl, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastStats forecastStats, + ForecastSaveResultStrategy saveResultStrategy + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + executionTtl, + modelManager, + checkpointDao, + entityColdStartQueue, + stateManager, + indexUtil, + cacheProvider, + stateTtl, + checkpointWriteQueue, + forecastStats, + FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, + FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME, + StatNames.FORECAST_MODEL_CORRUTPION_COUNT, + AnalysisType.FORECAST, + saveResultStrategy + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java new file mode 100644 index 000000000..e1e6b1903 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java @@ -0,0 +1,78 @@ +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointWriteWorker extends + CheckpointWriteWorker { + public static final String WORKER_NAME = "forecast-checkpoint-write"; + + public ForecastCheckpointWriteWorker( + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastCheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager timeSeriesNodeStateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSize, + 
singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + timeSeriesNodeStateManager, + checkpoint, + indexName, + checkpointInterval, + AnalysisType.FORECAST + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java new file mode 100644 index 000000000..43831f8df --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * A queue slowly releasing low-priority requests to CheckpointReadQueue. + * + * ColdEntityQueue is a queue to absorb cold entities. Like hot entities, we load a cold + * entity's model checkpoint from disk, train models if the checkpoint is not found, + * query for missed features to complete a shingle, use the models to generate forecasts, + * update models, and save the forecast results. + * Implementation-wise, we reuse the queues we have developed for hot entities. + * The differences are: we process hot entities as long as resources (e.g., the forecast + * thread pool has availability) are available, while we release cold entity requests + * to other queues at a slow controlled pace. Also, cold entity requests' priority is low. + * So only when there are no hot entity requests to process are we going to process cold + * entity requests.
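+ * + * As a rough sketch of the pacing (inferred from the settings wired into the constructor below, not a separate spec): each round forwards at most + * FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE requests and then pauses for roughly the batch size times + * FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, so cold entities trickle into the checkpoint read queue instead of flooding it.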
+ * + */ +public class ForecastColdEntityWorker extends + ColdEntityWorker { + public static final String WORKER_NAME = "forecast-cold-entity"; + + public ForecastColdEntityWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService forecastCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + ForecastCheckpointReadWorker checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + forecastCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager, + FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnalysisType.FORECAST + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java new file mode 100644 index 000000000..8aff62157 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastColdStartWorker extends + ColdStartWorker { + public static final String WORKER_NAME = "forecast-hc-cold-start"; + + public ForecastColdStartWorker( + long heapSizeInBytes, 
+ int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService circuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastColdStart coldStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + ForecastPriorityCache cacheProvider, + ForecastModelManager forecastModelManager, + ForecastSaveResultStrategy saveStrategy + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + circuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + coldStarter, + stateTtl, + nodeStateManager, + cacheProvider, + AnalysisType.FORECAST, + forecastModelManager, + saveStrategy + ); + } + + @Override + protected ModelState createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId) { + return new ModelState( + null, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + new Sample(), + coldStartRequest.getEntity(), + new ArrayDeque<>() + ); + } + + @Override + protected ForecastResult createIndexableResult( + Config config, + String taskId, + String modelId, + Entry entry, + Optional entity + ) { + return new ForecastResult( + config.getId(), + taskId, + ParseUtils.getFeatureData(entry.getValue(), config), + Instant.ofEpochMilli(entry.getKey() - config.getIntervalInMilliseconds()), + Instant.ofEpochMilli(entry.getKey()), + Instant.now(), + Instant.now(), + "", + entity, + config.getUser(), + config.getSchemaVersion() + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java new file mode 100644 index 000000000..f9fd07c25 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.ratelimit; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ForecastResultWriteRequest extends ResultWriteRequest { + + public ForecastResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ForecastResult result, + String resultIndex + ) { + super(expirationEpochMs, configId, priority, result, resultIndex); + } + + public ForecastResultWriteRequest(StreamInput in) throws IOException { + super(in, ForecastResult::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java new file mode 100644 index 000000000..7f991bcf6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; + +public class ForecastResultWriteWorker extends + ResultWriteWorker { + public static final String WORKER_NAME = "forecast-result-write"; + + public ForecastResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastIndexMemoryPressureAwareResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + 
maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager, + resultHandler, + xContentRegistry, + ForecastResult::parse, + AnalysisType.FORECAST + ); + } + + @Override + protected ForecastResultBulkRequest toBatchRequest(List toProcess) { + final ForecastResultBulkRequest bulkRequest = new ForecastResultBulkRequest(); + for (ForecastResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected ForecastResultWriteRequest createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ForecastResult result, + String resultIndex + ) { + return new ForecastResultWriteRequest(expirationEpochMs, configId, priority, result, resultIndex); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastSaveResultStrategy.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastSaveResultStrategy.java new file mode 100644 index 000000000..1dc3029a0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastSaveResultStrategy.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.util.ParseUtils; + +public class ForecastSaveResultStrategy implements SaveResultStrategy { + private int resultMappingVersion; + private ForecastResultWriteWorker resultWriteWorker; + + public ForecastSaveResultStrategy(int resultMappingVersion, ForecastResultWriteWorker resultWriteWorker) { + this.resultMappingVersion = resultMappingVersion; + this.resultWriteWorker = resultWriteWorker; + } + + @Override + public void saveResult(RCFCasterResult result, Config config, FeatureRequest origRequest, String modelId) { + saveResult( + result, + config, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()), + modelId, + origRequest.getCurrentFeature(), + origRequest.getEntity(), + origRequest.getTaskId() + ); + } + + @Override + public void saveResult( + RCFCasterResult result, + Config config, + Instant dataStart, + Instant dataEnd, + String modelId, + double[] currentData, + Optional entity, + String taskId + ) { + if (result != null && result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + dataStart, + dataEnd, + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(currentData, config), + entity, + resultMappingVersion, + modelId, + taskId, + null + ); + + for (ForecastResult r : indexableResults) { + saveResult(r, config); + } + } + } + + @Override + public void saveResult(ForecastResult result, Config config) 
{ + resultWriteWorker + .put( + new ForecastResultWriteRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + config.getId(), + RequestPriority.MEDIUM, + result, + config.getCustomResultIndex() + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/AbstractForecastSearchAction.java b/src/main/java/org/opensearch/forecast/rest/AbstractForecastSearchAction.java new file mode 100644 index 000000000..0146981ca --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/AbstractForecastSearchAction.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.timeseries.AbstractSearchAction; + +public abstract class AbstractForecastSearchAction extends AbstractSearchAction { + + public AbstractForecastSearchAction( + List urlPaths, + List> deprecatedPaths, + String index, + Class clazz, + ActionType actionType + ) { + super( + urlPaths, + deprecatedPaths, + index, + clazz, + actionType, + ForecastEnabledSetting::isForecastEnabled, + ForecastCommonMessages.DISABLED_ERR_MSG + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java new file mode 100644 index 000000000..bac785c79 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_WINDOW_DELAY; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_FORECAST_FEATURES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_HC_FORECASTERS; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.rest.BaseRestHandler; + +/** + * This class consists of the base class for validating and indexing forecast REST handlers. + */ +public abstract class AbstractForecasterAction extends BaseRestHandler { + /** + * Timeout duration for the forecast request. + */ + protected volatile TimeValue requestTimeout; + + /** + * Interval at which forecasts are generated. + */ + protected volatile TimeValue forecastInterval; + + /** + * Delay duration before the forecast window begins. + */ + protected volatile TimeValue forecastWindowDelay; + + /** + * Maximum number of single stream forecasters allowed. 
+ */ + protected volatile Integer maxSingleStreamForecasters; + + /** + * Maximum number of high-cardinality (HC) forecasters allowed. + */ + protected volatile Integer maxHCForecasters; + + /** + * Maximum number of features to be used for forecasting. + */ + protected volatile Integer maxForecastFeatures; + + /** + * Maximum number of categorical fields allowed. + */ + protected volatile Integer maxCategoricalFields; + + /** + * Constructor for the base class for validating and indexing forecast REST handlers. + * + * @param settings Settings for the forecast plugin. + * @param clusterService Cluster service. + */ + public AbstractForecasterAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = FORECAST_REQUEST_TIMEOUT.get(settings); + this.forecastInterval = FORECAST_INTERVAL.get(settings); + this.forecastWindowDelay = FORECAST_WINDOW_DELAY.get(settings); + this.maxSingleStreamForecasters = MAX_SINGLE_STREAM_FORECASTERS.get(settings); + this.maxHCForecasters = MAX_HC_FORECASTERS.get(settings); + this.maxForecastFeatures = MAX_FORECAST_FEATURES; + this.maxCategoricalFields = ForecastNumericSetting.maxCategoricalFields(); + // TODO: will add more cluster setting consumer later + // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_REQUEST_TIMEOUT, it -> requestTimeout = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INTERVAL, it -> forecastInterval = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_WINDOW_DELAY, it -> forecastWindowDelay = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_SINGLE_STREAM_FORECASTERS, it -> maxSingleStreamForecasters = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_HC_FORECASTERS, it -> maxHCForecasters = it); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java b/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java new file mode 100644 index 000000000..9ba626fcd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java @@ -0,0 +1,141 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Input data needed to trigger forecaster. 
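+ * <p>A serialized request body has the following shape (the field names come from the constants declared below; the values are illustrative): + * <pre> + * { "forecaster_id": "abc123", "period_start": 1672531200000, "period_end": 1672534800000, "forecaster": { ... } } + * </pre>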
+ */ +public class ForecasterExecutionInput implements ToXContentObject { + + private static final String FORECASTER_ID_FIELD = "forecaster_id"; + private static final String PERIOD_START_FIELD = "period_start"; + private static final String PERIOD_END_FIELD = "period_end"; + private static final String FORECASTER_FIELD = "forecaster"; + private Instant periodStart; + private Instant periodEnd; + private String forecasterId; + private Forecaster forecaster; + + public ForecasterExecutionInput(String forecasterId, Instant periodStart, Instant periodEnd, Forecaster forecaster) { + this.periodStart = periodStart; + this.periodEnd = periodEnd; + this.forecasterId = forecasterId; + this.forecaster = forecaster; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(FORECASTER_ID_FIELD, forecasterId) + .field(PERIOD_START_FIELD, periodStart.toEpochMilli()) + .field(PERIOD_END_FIELD, periodEnd.toEpochMilli()) + .field(FORECASTER_FIELD, forecaster); + return xContentBuilder.endObject(); + } + + public static ForecasterExecutionInput parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static ForecasterExecutionInput parse(XContentParser parser, String inputConfigId) throws IOException { + Instant periodStart = null; + Instant periodEnd = null; + Forecaster forecaster = null; + String forecasterId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case FORECASTER_ID_FIELD: + forecasterId = parser.text(); + break; + case PERIOD_START_FIELD: + periodStart = ParseUtils.toInstant(parser); + break; + case PERIOD_END_FIELD: + periodEnd = ParseUtils.toInstant(parser); + break; + case FORECASTER_FIELD: + if (parser.currentToken().equals(XContentParser.Token.START_OBJECT)) { + forecaster = Forecaster.parse(parser, forecasterId); + } + break; + default: + break; + } + } + if (!Strings.isNullOrEmpty(inputConfigId)) { + forecasterId = inputConfigId; + } + return new ForecasterExecutionInput(forecasterId, periodStart, periodEnd, forecaster); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ForecasterExecutionInput that = (ForecasterExecutionInput) o; + return Objects.equal(periodStart, that.periodStart) + && Objects.equal(periodEnd, that.periodEnd) + && Objects.equal(forecasterId, that.forecasterId) + && Objects.equal(forecaster, that.forecaster); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(periodStart, periodEnd, forecasterId); + } + + public Instant getPeriodStart() { + return periodStart; + } + + public Instant getPeriodEnd() { + return periodEnd; + } + + public String getForecasterId() { + return forecasterId; + } + + public void setForecasterId(String forecasterId) { + this.forecasterId = forecasterId; + } + + public Forecaster getForecaster() { + return forecaster; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestDeleteForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestDeleteForecasterAction.java new file mode 100644 index 000000000..54be0bb31 --- /dev/null +++ 
b/src/main/java/org/opensearch/forecast/rest/RestDeleteForecasterAction.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.DeleteForecasterAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.DeleteConfigRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +public class RestDeleteForecasterAction extends BaseRestHandler { + public static final String DELETE_FORECASTER_ACTION = "delete_forecaster"; + + public RestDeleteForecasterAction() {} + + @Override + public String getName() { + return DELETE_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + DeleteConfigRequest deleteForecasterRequest = new DeleteConfigRequest(forecasterId); + return channel -> client + .execute(DeleteForecasterAction.INSTANCE, deleteForecasterRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + + } + + @Override + public List routes() { + return ImmutableList + .of( + // delete forecaster document + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java b/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java new file mode 100644 index 000000000..a5f98829b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.ForecasterJobAction; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.rest.RestJobAction; +import org.opensearch.timeseries.transport.JobRequest; +import org.owasp.encoder.Encode; + +import 
com.google.common.collect.ImmutableList; + +public class RestForecasterJobAction extends RestJobAction { + public static final String FORECAST_JOB_ACTION = "forecaster_job_action"; + + @Override + public String getName() { + return FORECAST_JOB_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + String rawPath = request.rawPath(); + DateRange dateRange = parseInputDateRange(request); + + // false means we don't support backtesting and thus no need to stop backtesting + JobRequest forecasterJobRequest = new JobRequest(forecasterId, dateRange, false, rawPath); + + return channel -> client.execute(ForecasterJobAction.INSTANCE, forecasterJobRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + + } + + @Override + public List routes() { + return ImmutableList + .of( + /// start forecaster Job + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, START_JOB) + ), + /// stop forecaster Job + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, STOP_JOB) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestForecasterSuggestAction.java b/src/main/java/org/opensearch/forecast/rest/RestForecasterSuggestAction.java new file mode 100644 index 000000000..a7b9b8132 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestForecasterSuggestAction.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; +import static org.opensearch.timeseries.util.RestHandlerUtils.SUGGEST; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.ValidationException; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.SuggestForecasterParamAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.SuggestConfigParamRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to suggest forecaster configuration parameters. + */ +public class RestForecasterSuggestAction extends BaseRestHandler { + private static final String FORECASTER_SUGGEST_ACTION = "forecaster_suggest_action"; + + private volatile TimeValue requestTimeout; + + public RestForecasterSuggestAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = FORECAST_REQUEST_TIMEOUT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_REQUEST_TIMEOUT, it -> requestTimeout = it); + } + + @Override + public String getName() { + return FORECASTER_SUGGEST_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, SUGGEST, TYPE) + ) + ); + } + + private boolean suggestTypesAreAccepted(String suggestType) { + Set typesInRequest = new HashSet<>(Arrays.asList(suggestType.split(","))); + // only support interval suggest now + return (!Collections.disjoint(typesInRequest, Set.of(Forecaster.FORECAST_INTERVAL_FIELD))); + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + // we have to get the param from a subclass of BaseRestHandler. 
Otherwise, we cannot parse the type out of request params + String typesStr = request.param(TYPE); + + // if the type param isn't blank and isn't one of the supported suggest types, throw an exception + if (!StringUtils.isBlank(typesStr)) { + if (!suggestTypesAreAccepted(typesStr)) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError(CommonMessages.NOT_EXISTENT_SUGGEST_TYPE); + throw validationException; + } + } + + Forecaster config = parseConfig(parser); + + if (config != null) { + return channel -> { + SuggestConfigParamRequest suggestForecasterParamRequest = new SuggestConfigParamRequest( + AnalysisType.FORECAST, + config, + typesStr, + requestTimeout + ); + client + .execute( + SuggestForecasterParamAction.INSTANCE, + suggestForecasterParamRequest, + new RestToXContentListener<>(channel) + ); + }; + } else { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("failed to parse config"); + throw validationException; + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + private Forecaster parseConfig(XContentParser parser) throws IOException { + try { + // use default forecaster interval in case of validation exception since it can be empty + return Forecaster.parse(parser, null, null, new TimeValue(1, TimeUnit.MINUTES), null); + } catch (Exception e) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError(e.getMessage()); + throw validationException; + } + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java new file mode 100644 index 000000000..c36a925a4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.GetForecasterAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to retrieve a forecaster. 
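+ * <p>For example (an illustration of the request flags handled in prepareRequest below, not an exhaustive API reference): GET {FORECAST_FORECASTERS_URI}/{forecaster_id}?job=true&task=true returns the forecaster together with its job and task.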
+ */ +public class RestGetForecasterAction extends BaseRestHandler { + + private static final String GET_FORECASTER_ACTION = "get_forecaster"; + + public RestGetForecasterAction() {} + + @Override + public String getName() { + return GET_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID); + String typesStr = request.param(TYPE); + + String rawPath = request.rawPath(); + boolean returnJob = request.paramAsBoolean("job", false); + boolean returnTask = request.paramAsBoolean("task", false); + boolean all = request.paramAsBoolean("_all", false); + GetConfigRequest getForecasterRequest = new GetConfigRequest( + forecasterId, + RestActions.parseVersion(request), + returnJob, + returnTask, + typesStr, + rawPath, + all, + RestHandlerUtils.buildEntity(request, forecasterId) + ); + + return channel -> client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + // OpenSearch-only API. Considering users may provide entity in the search body, + // support POST as well. + + // profile API + new Route( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE + ) + ), + // types is a list of profile names. See a complete list of supported profile names in + // org.opensearch.ad.model.ProfileName. + new Route( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE, + TYPE + ) + ), + new Route( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE + ) + ), + // types is a list of profile names. See a complete list of supported profile names in + // org.opensearch.ad.model.ProfileName. + new Route( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE, + TYPE + ) + ), + + // get forecaster API + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java new file mode 100644 index 000000000..24a9ab037 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java @@ -0,0 +1,148 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; +import static org.opensearch.timeseries.util.RestHandlerUtils.REFRESH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.IndexForecasterAction; +import org.opensearch.forecast.transport.IndexForecasterRequest; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.Config; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * REST handlers to create and update forecasters. + */ +public class RestIndexForecasterAction extends AbstractForecasterAction { + private static final String INDEX_FORECASTER_ACTION = "index_forecaster_action"; + private final Logger logger = LogManager.getLogger(RestIndexForecasterAction.class); + + public RestIndexForecasterAction(Settings settings, ClusterService clusterService) { + super(settings, clusterService); + } + + @Override + public String getName() { + return INDEX_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterId = request.param(FORECASTER_ID, Config.NO_ID); + logger.info("Forecaster {} action for forecasterId {}", request.method(), forecasterId); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Forecaster forecaster = Forecaster.parse(parser, forecasterId, null, forecastInterval, forecastWindowDelay); + + long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); + long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + WriteRequest.RefreshPolicy refreshPolicy = request.hasParam(REFRESH) + ? 
WriteRequest.RefreshPolicy.parse(request.param(REFRESH)) + : WriteRequest.RefreshPolicy.IMMEDIATE; + RestRequest.Method method = request.getHttpRequest().method(); + + IndexForecasterRequest indexAnomalyDetectorRequest = new IndexForecasterRequest( + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + method, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields + ); + + return channel -> client + .execute(IndexForecasterAction.INSTANCE, indexAnomalyDetectorRequest, indexForecasterResponse(channel, method)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } catch (ValidationException e) { + // convert 500 to 400 errors for validation failures + throw new OpenSearchStatusException(e.getMessage(), RestStatus.BAD_REQUEST); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + // Create + new Route(RestRequest.Method.POST, TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI), + // Update + new Route( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } + + private RestResponseListener indexForecasterResponse(RestChannel channel, RestRequest.Method method) { + return new RestResponseListener(channel) { + @Override + public RestResponse buildResponse(IndexForecasterResponse response) throws Exception { + RestStatus restStatus = RestStatus.CREATED; + if (method == RestRequest.Method.PUT) { + restStatus = RestStatus.OK; + } + BytesRestResponse bytesRestResponse = new BytesRestResponse( + restStatus, + response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS) + ); + if (restStatus == RestStatus.CREATED) { + String location = String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI, response.getId()); + bytesRestResponse.addHeader("Location", location); + } + return bytesRestResponse; + } + }; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestRunOnceForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestRunOnceForecasterAction.java new file mode 100644 index 000000000..042e21820 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestRunOnceForecasterAction.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.forecast.rest;
+
+import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID;
+import static org.opensearch.timeseries.util.RestHandlerUtils.RUN_ONCE;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import org.joda.time.Instant;
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.forecast.constant.ForecastCommonMessages;
+import org.opensearch.forecast.settings.ForecastEnabledSetting;
+import org.opensearch.forecast.transport.ForecastResultRequest;
+import org.opensearch.forecast.transport.ForecastRunOnceAction;
+import org.opensearch.rest.BaseRestHandler;
+import org.opensearch.rest.RestRequest;
+import org.opensearch.rest.action.RestToXContentListener;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.owasp.encoder.Encode;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * This class consists of the REST handler to run a forecaster once.
+ */
+public class RestRunOnceForecasterAction extends BaseRestHandler {
+
+    public static final String FORECASTER_ACTION = "run_forecaster_once";
+
+    public RestRunOnceForecasterAction() {}
+
+    @Override
+    public String getName() {
+        return FORECASTER_ACTION;
+    }
+
+    @Override
+    public List<Route> routes() {
+        return ImmutableList
+            .of(
+                // execute forecaster once
+                new Route(
+                    RestRequest.Method.POST,
+                    String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, RUN_ONCE)
+                )
+            );
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
+        if (!ForecastEnabledSetting.isForecastEnabled()) {
+            throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG);
+        }
+
+        try {
+            String forecasterId = request.param(FORECASTER_ID);
+
+            ForecastResultRequest getRequest = new ForecastResultRequest(
+                forecasterId,
+                -1L, // will set it in ResultProcessor.onGetConfig
+                Instant.now().getMillis()
+            );
+
+            return channel -> client.execute(ForecastRunOnceAction.INSTANCE, getRequest, new RestToXContentListener<>(channel));
+        } catch (IllegalArgumentException e) {
+            throw new IllegalArgumentException(Encode.forHtml(e.getMessage()));
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchForecastTasksAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchForecastTasksAction.java
new file mode 100644
index 000000000..6b72e42e6
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/rest/RestSearchForecastTasksAction.java
@@ -0,0 +1,44 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.rest;
+
+import org.opensearch.forecast.indices.ForecastIndex;
+import org.opensearch.forecast.model.ForecastTask;
+import org.opensearch.forecast.transport.SearchForecastTasksAction;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * This class consists of the REST handler to search forecast tasks.
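+ * <p>Example (editorial sketch; assumes FORECAST_FORECASTERS_URI resolves to
+ * <code>/_plugins/_forecast/forecasters</code>, which is not shown in this diff,
+ * and the query body is illustrative):
+ * <pre>
+ * POST /_plugins/_forecast/forecasters/tasks/_search
+ * { "query": { "term": { "is_latest": true } } }
+ * </pre>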
+ */
+public class RestSearchForecastTasksAction extends AbstractForecastSearchAction<ForecastTask> {
+
+    private static final String URL_PATH = TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI + "/tasks/_search";
+    private final String SEARCH_FORECASTER_TASKS = "search_forecaster_tasks";
+
+    public RestSearchForecastTasksAction() {
+        super(
+            ImmutableList.of(URL_PATH),
+            ImmutableList.of(),
+            ForecastIndex.STATE.getIndexName(),
+            ForecastTask.class,
+            SearchForecastTasksAction.INSTANCE
+        );
+    }
+
+    @Override
+    public String getName() {
+        return SEARCH_FORECASTER_TASKS;
+    }
+
+}
diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterAction.java
new file mode 100644
index 000000000..1e5d76b7a
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterAction.java
@@ -0,0 +1,39 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.rest;
+
+import static org.opensearch.timeseries.util.RestHandlerUtils.SEARCH;
+
+import org.opensearch.forecast.model.Forecaster;
+import org.opensearch.forecast.transport.SearchForecasterAction;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.constant.CommonName;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * This class consists of the REST handler to search forecasters.
+ */
+public class RestSearchForecasterAction extends AbstractForecastSearchAction<Forecaster> {
+
+    private static final String URL_PATH = TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI + "/" + SEARCH;
+    private final String SEARCH_FORECASTER_ACTION = "search_forecaster";
+
+    public RestSearchForecasterAction() {
+        super(ImmutableList.of(URL_PATH), ImmutableList.of(), CommonName.CONFIG_INDEX, Forecaster.class, SearchForecasterAction.INSTANCE);
+    }
+
+    @Override
+    public String getName() {
+        return SEARCH_FORECASTER_ACTION;
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterInfoAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterInfoAction.java
new file mode 100644
index 000000000..16cc54ecd
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/rest/RestSearchForecasterInfoAction.java
@@ -0,0 +1,77 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.COUNT; +import static org.opensearch.timeseries.util.RestHandlerUtils.MATCH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.SearchForecasterInfoAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.SearchConfigInfoRequest; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +public class RestSearchForecasterInfoAction extends BaseRestHandler { + + public static final String SEARCH_FORECASTER_INFO_ACTION = "search_forecaster_info"; + + public RestSearchForecasterInfoAction() {} + + @Override + public String getName() { + return SEARCH_FORECASTER_INFO_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, org.opensearch.client.node.NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + String forecasterName = request.param("name", null); + String rawPath = request.rawPath(); + + SearchConfigInfoRequest searchForecasterInfoRequest = new SearchConfigInfoRequest(forecasterName, rawPath); + return channel -> client + .execute(SearchForecasterInfoAction.INSTANCE, searchForecasterInfoRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, COUNT) + ), + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, MATCH) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestSearchTopForecastResultAction.java b/src/main/java/org/opensearch/forecast/rest/RestSearchTopForecastResultAction.java new file mode 100644 index 000000000..49e922e9b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestSearchTopForecastResultAction.java @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.forecast.rest;
+
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.forecast.constant.ForecastCommonMessages;
+import org.opensearch.forecast.settings.ForecastEnabledSetting;
+import org.opensearch.forecast.transport.SearchTopForecastResultAction;
+import org.opensearch.forecast.transport.SearchTopForecastResultRequest;
+import org.opensearch.rest.BaseRestHandler;
+import org.opensearch.rest.RestRequest;
+import org.opensearch.rest.action.RestToXContentListener;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.util.RestHandlerUtils;
+import org.owasp.encoder.Encode;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * The REST handler to search top forecast results for HC forecasters.
+ */
+public class RestSearchTopForecastResultAction extends BaseRestHandler {
+
+    private static final String URL_PATH = String
+        .format(
+            Locale.ROOT,
+            "%s/{%s}/%s/%s",
+            TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI,
+            RestHandlerUtils.FORECASTER_ID,
+            RestHandlerUtils.RESULTS,
+            RestHandlerUtils.TOP_FORECASTS
+        );
+    private final String SEARCH_TOP_FORECASTS_ACTION = "search_top_forecasts";
+
+    public RestSearchTopForecastResultAction() {}
+
+    @Override
+    public String getName() {
+        return SEARCH_TOP_FORECASTS_ACTION;
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
+        // Throw error if disabled
+        if (!ForecastEnabledSetting.isForecastEnabled()) {
+            throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG);
+        }
+
+        try {
+            // Get the typed request
+            SearchTopForecastResultRequest searchTopForecastResultRequest = getSearchTopForecastResultRequest(request);
+
+            return channel -> client
+                .execute(SearchTopForecastResultAction.INSTANCE, searchTopForecastResultRequest, new RestToXContentListener<>(channel));
+        } catch (IllegalArgumentException e) {
+            throw new IllegalArgumentException(Encode.forHtml(e.getMessage()));
+        }
+    }
+
+    private SearchTopForecastResultRequest getSearchTopForecastResultRequest(RestRequest request) throws IOException {
+        String forecasterId;
+        if (request.hasParam(RestHandlerUtils.FORECASTER_ID)) {
+            forecasterId = request.param(RestHandlerUtils.FORECASTER_ID);
+        } else {
+            throw new IllegalStateException(ForecastCommonMessages.FORECASTER_ID_MISSING_MSG);
+        }
+        XContentParser parser = request.contentParser();
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+        return SearchTopForecastResultRequest.parse(parser, forecasterId);
+    }
+
+    @Override
+    public List<Route> routes() {
+        return ImmutableList.of(new Route(RestRequest.Method.POST, URL_PATH), new Route(RestRequest.Method.GET, URL_PATH));
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/rest/RestStatsForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestStatsForecasterAction.java
new file mode 100644
index 000000000..e4f6f5edd
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/rest/RestStatsForecasterAction.java
@@ -0,0 +1,77 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.transport.StatsForecasterAction; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.rest.RestStatsAction; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.owasp.encoder.Encode; + +import com.google.common.collect.ImmutableList; + +/** + * RestStatsForecasterAction consists of the REST handler to get the stats from forecasting. + */ +public class RestStatsForecasterAction extends RestStatsAction { + + private static final String STATS_FORECASTER_ACTION = "stats_forecaster"; + + /** + * Constructor + * + * @param timeSeriesStats TimeSeriesStats object + * @param nodeFilter util class to get eligible data nodes + */ + public RestStatsForecasterAction(ForecastStats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + super(timeSeriesStats, nodeFilter); + } + + @Override + public String getName() { + return STATS_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + try { + StatsRequest forecastStatsRequest = getRequest(request); + return channel -> client.execute(StatsForecasterAction.INSTANCE, forecastStatsRequest, new RestToXContentListener<>(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/{nodeId}/stats/"), + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/{nodeId}/stats/{stat}"), + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/stats/"), + new Route(RestRequest.Method.GET, TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI + "/stats/{stat}") + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestValidateForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestValidateForecasterAction.java new file mode 100644 index 000000000..93ff62288 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestValidateForecasterAction.java @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.forecast.rest;
+
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE;
+import static org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import org.opensearch.client.node.NodeClient;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.forecast.constant.ForecastCommonMessages;
+import org.opensearch.forecast.settings.ForecastEnabledSetting;
+import org.opensearch.forecast.transport.ValidateForecasterAction;
+import org.opensearch.rest.BaseRestHandler;
+import org.opensearch.rest.RestRequest;
+import org.opensearch.rest.action.RestToXContentListener;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.common.exception.ValidationException;
+import org.opensearch.timeseries.model.ConfigValidationIssue;
+import org.opensearch.timeseries.rest.RestValidateAction;
+import org.opensearch.timeseries.transport.ValidateConfigRequest;
+import org.owasp.encoder.Encode;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * This class consists of the REST handler to validate forecaster configurations.
+ */
+public class RestValidateForecasterAction extends AbstractForecasterAction {
+    private static final String VALIDATE_FORECASTER_ACTION = "validate_forecaster_action";
+
+    private RestValidateAction validateAction;
+
+    public RestValidateForecasterAction(Settings settings, ClusterService clusterService) {
+        super(settings, clusterService);
+        this.validateAction = new RestValidateAction(
+            AnalysisType.FORECAST,
+            maxSingleStreamForecasters,
+            maxHCForecasters,
+            maxForecastFeatures,
+            maxCategoricalFields,
+            requestTimeout
+        );
+    }
+
+    @Override
+    public String getName() {
+        return VALIDATE_FORECASTER_ACTION;
+    }
+
+    @Override
+    public List<Route> routes() {
+        return ImmutableList
+            .of(
+                new Route(
+                    RestRequest.Method.POST,
+                    String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, VALIDATE)
+                ),
+                new Route(
+                    RestRequest.Method.POST,
+                    String.format(Locale.ROOT, "%s/%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, VALIDATE, TYPE)
+                )
+            );
+    }
+
+    @Override
+    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
+        if (!ForecastEnabledSetting.isForecastEnabled()) {
+            throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG);
+        }
+
+        try {
+            XContentParser parser = request.contentParser();
+            ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+            // we have to get the param from a subclass of BaseRestHandler.
Otherwise, we cannot parse the type out of request params + String typesStr = request.param(TYPE); + + return channel -> { + try { + ValidateConfigRequest validateForecasterRequest = validateAction.prepareRequest(request, client, typesStr); + client.execute(ValidateForecasterAction.INSTANCE, validateForecasterRequest, new RestToXContentListener<>(channel)); + } catch (Exception ex) { + if (ex instanceof ValidationException) { + ValidationException forecastException = (ValidationException) ex; + ConfigValidationIssue issue = new ConfigValidationIssue( + forecastException.getAspect(), + forecastException.getType(), + forecastException.getMessage() + ); + validateAction.sendValidationParseResponse(issue, channel); + } else { + throw ex; + } + } + }; + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java new file mode 100644 index 000000000..f664ae193 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java @@ -0,0 +1,247 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest.handler; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.Locale; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import 
org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class AbstractForecasterActionHandler extends + AbstractTimeSeriesActionHandler { + protected final Logger logger = LogManager.getLogger(AbstractForecasterActionHandler.class); + + public static final String EXCEEDED_MAX_HC_FORECASTERS_PREFIX_MSG = "Can't create more than %d HC forecasters."; + public static final String EXCEEDED_MAX_SINGLE_STREAM_FORECASTERS_PREFIX_MSG = "Can't create more than %d single-stream forecasters."; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create forecasters as no document is found in the indices: "; + public static final String DUPLICATE_FORECASTER_MSG = + "Cannot create forecasters with name [%s] as it's already used by another forecaster"; + public static final String VALIDATION_FEATURE_FAILURE = "Validation failed for feature(s) of forecaster %s"; + + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil Forecast security client + * @param transportService ES transport service + * @param forecastIndices forecast index manager + * @param forecasterId forecaster identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param forecaster forecaster instance + * @param requestTimeout request time out configuration + * @param maxSingleStreamForecasters max single-stream forecasters allowed + * @param maxHCForecasters max HC forecasters allowed + * @param maxFeatures max features allowed per forecaster + * @param maxCategoricalFields max categorical fields allowed + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + * @param clock clock object to know when to timeout + * @param isDryRun Whether handler is dryrun or not + */ + public AbstractForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ForecastIndexManagement forecastIndices, + String forecasterId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Config forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxFeatures, + Integer maxCategoricalFields, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + ForecastTaskManager forecastTaskManager, + SearchFeatureDao searchFeatureDao, + String validationType, + boolean isDryRun, + Clock clock, + Settings settings + ) { + super( + forecaster, + forecastIndices, + isDryRun, + client, + forecasterId, + clientUtil, + user, + method, + clusterService, + xContentRegistry, + transportService, + requestTimeout, + refreshPolicy, + seqNo, + primaryTerm, + validationType, + searchFeatureDao, + maxFeatures, + maxCategoricalFields, + AnalysisType.FORECAST, + forecastTaskManager, + ForecastTaskType.RUN_ONCE_TASK_TYPES, + true, + maxSingleStreamForecasters, + maxHCForecasters, + clock, + settings + ); + } + + @Override + protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) { + return new ValidationException(msg, type, ValidationAspect.FORECASTER); + } + + @Override + protected Forecaster parse(XContentParser parser, GetResponse response) 
throws IOException { + return Forecaster.parse(parser, response.getId(), response.getVersion()); + } + + @Override + protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_SINGLE_STREAM_FORECASTERS_PREFIX_MSG, getMaxSingleStreamConfigs()); + } + + @Override + protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_HC_FORECASTERS_PREFIX_MSG, getMaxHCConfigs()); + } + + @Override + protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) { + return String.format(Locale.ROOT, NO_DOCS_IN_USER_INDEX_MSG, suppliedIndices); + } + + @Override + protected String getDuplicateConfigErrorMsg(String name) { + return String.format(Locale.ROOT, DUPLICATE_FORECASTER_MSG, name); + } + + @Override + protected Config copyConfig(User user, Config config) { + return new Forecaster( + config.getId(), + config.getVersion(), + config.getName(), + config.getDescription(), + config.getTimeField(), + config.getIndices(), + config.getFeatureAttributes(), + config.getFilterQuery(), + config.getInterval(), + config.getWindowDelay(), + config.getShingleSize(), + config.getUiMetadata(), + config.getSchemaVersion(), + Instant.now(), + config.getCategoryFields(), + user, + config.getCustomResultIndex(), + ((Forecaster) config).getHorizon(), + config.getImputationOption(), + config.getTransformDecay() + ); + } + + @SuppressWarnings("unchecked") + @Override + protected T createIndexConfigResponse(IndexResponse indexResponse, Config config) { + return (T) new IndexForecasterResponse( + indexResponse.getId(), + indexResponse.getVersion(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + (Forecaster) config, + RestStatus.CREATED + ); + } + + @Override + protected Set getDefaultValidationType() { + return Sets.newHashSet(ValidationAspect.FORECASTER); + } + + @Override + protected String getFeatureErrorMsg(String name) { + return String.format(Locale.ROOT, VALIDATION_FEATURE_FAILURE, name); + } + + @Override + protected void validateModel(ActionListener listener) { + ForecastModelValidationActionHandler modelValidationActionHandler = new ForecastModelValidationActionHandler( + clusterService, + client, + clientUtil, + (ActionListener) listener, + (Forecaster) config, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user + ); + modelValidationActionHandler.start(); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java new file mode 100644 index 000000000..c746eba79 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest.handler; + +import static org.opensearch.forecast.model.ForecastTaskType.RUN_ONCE_TASK_TYPES; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; + +import java.util.List; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import 
org.opensearch.forecast.ExecuteForecastResultResponseRecorder;
+import org.opensearch.forecast.indices.ForecastIndex;
+import org.opensearch.forecast.indices.ForecastIndexManagement;
+import org.opensearch.forecast.model.ForecastResult;
+import org.opensearch.forecast.model.ForecastTask;
+import org.opensearch.forecast.model.ForecastTaskType;
+import org.opensearch.forecast.task.ForecastTaskManager;
+import org.opensearch.forecast.transport.ForecastProfileAction;
+import org.opensearch.forecast.transport.ForecastResultAction;
+import org.opensearch.forecast.transport.ForecastResultRequest;
+import org.opensearch.forecast.transport.StopForecasterAction;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.NodeStateManager;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.rest.handler.IndexJobActionHandler;
+import org.opensearch.timeseries.task.TaskCacheManager;
+import org.opensearch.timeseries.transport.JobResponse;
+import org.opensearch.timeseries.transport.ResultRequest;
+import org.opensearch.transport.TransportService;
+
+public class ForecastIndexJobActionHandler extends
+    IndexJobActionHandler {
+
+    public ForecastIndexJobActionHandler(
+        Client client,
+        ForecastIndexManagement indexManagement,
+        NamedXContentRegistry xContentRegistry,
+        ForecastTaskManager forecastTaskManager,
+        ExecuteForecastResultResponseRecorder recorder,
+        NodeStateManager nodeStateManager,
+        Settings settings
+    ) {
+        super(
+            client,
+            indexManagement,
+            xContentRegistry,
+            forecastTaskManager,
+            recorder,
+            ForecastResultAction.INSTANCE,
+            AnalysisType.FORECAST,
+            ForecastIndex.STATE.getIndexName(),
+            StopForecasterAction.INSTANCE,
+            nodeStateManager,
+            settings,
+            FORECAST_REQUEST_TIMEOUT
+        );
+    }
+
+    @Override
+    protected ResultRequest createResultRequest(String configID, long start, long end) {
+        return new ForecastResultRequest(configID, start, end);
+    }
+
+    @Override
+    protected List<ForecastTaskType> getBatchConfigTaskTypes() {
+        return RUN_ONCE_TASK_TYPES;
+    }
+
+    /**
+     * Stop a config.
+     * For realtime, the job is set as disabled.
+     * For run-once, its task is set as inactive.
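+     * <p>Caller sketch (editorial; the listener bodies are illustrative):
+     * <pre>{@code
+     * jobHandler.stopConfig(configId, false, user, transportService,
+     *     ActionListener.wrap(
+     *         jobResponse -> logger.info("Stopped forecaster {}", configId),
+     *         e -> logger.error("Failed to stop forecaster " + configId, e)));
+     * }</pre>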
+ * + * @param configId config id + * @param historical stop historical analysis or not + * @param user user + * @param transportService transport service + * @param listener action listener + */ + @Override + public void stopConfig( + String configId, + boolean historical, + User user, + TransportService transportService, + ActionListener listener + ) { + // make sure forecaster exists + nodeStateManager.getConfig(configId, AnalysisType.FORECAST, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, ForecastTaskType.RUN_ONCE_TASK_TYPES, (task) -> { + // stop realtime forecaster job + stopJob(configId, transportService, listener); + }, transportService, true, listener); // true means reset task state as inactive/stopped state + }, listener); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ForecastModelValidationActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ForecastModelValidationActionHandler.java new file mode 100644 index 000000000..f03c1fdc7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/ForecastModelValidationActionHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest.handler; + +import java.time.Clock; + +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.rest.handler.ModelValidationActionHandler; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ForecastModelValidationActionHandler extends ModelValidationActionHandler { + + public ForecastModelValidationActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + Forecaster config, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings, + User user + ) { + super( + clusterService, + client, + clientUtil, + listener, + config, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user, + AnalysisType.FORECAST + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java new file mode 100644 index 000000000..580943441 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.forecast.rest.handler;
+
+import org.opensearch.action.support.WriteRequest;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.commons.authuser.User;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.forecast.indices.ForecastIndexManagement;
+import org.opensearch.forecast.model.Forecaster;
+import org.opensearch.forecast.task.ForecastTaskManager;
+import org.opensearch.forecast.transport.IndexForecasterResponse;
+import org.opensearch.rest.RestRequest;
+import org.opensearch.timeseries.feature.SearchFeatureDao;
+import org.opensearch.timeseries.util.SecurityClientUtil;
+import org.opensearch.transport.TransportService;
+
+/**
+ * Process create/update forecaster requests.
+ */
+public class IndexForecasterActionHandler extends AbstractForecasterActionHandler<IndexForecasterResponse> {
+    /**
+     * Constructor function.
+     *
+     * @param clusterService ClusterService
+     * @param client OS node client that executes actions on the local node
+     * @param clientUtil forecast security client
+     * @param transportService OS transport service
+     * @param forecastIndices forecast index manager
+     * @param forecasterId forecaster identifier
+     * @param seqNo sequence number of last modification
+     * @param primaryTerm primary term of last modification
+     * @param refreshPolicy refresh policy
+     * @param forecaster forecaster instance
+     * @param requestTimeout request time out configuration
+     * @param maxSingleStreamForecasters max single-stream forecasters allowed
+     * @param maxHCForecasters max HC forecasters allowed
+     * @param maxForecastFeatures max features allowed per forecaster
+     * @param maxCategoricalFields max number of categorical fields
+     * @param method Rest Method type
+     * @param xContentRegistry Registry which is used for XContentParser
+     * @param user User context
+     * @param taskManager forecast task manager
+     * @param searchFeatureDao search feature DAO
+     * @param settings node settings
+     */
+    public IndexForecasterActionHandler(
+        ClusterService clusterService,
+        Client client,
+        SecurityClientUtil clientUtil,
+        TransportService transportService,
+        ForecastIndexManagement forecastIndices,
+        String forecasterId,
+        Long seqNo,
+        Long primaryTerm,
+        WriteRequest.RefreshPolicy refreshPolicy,
+        Forecaster forecaster,
+        TimeValue requestTimeout,
+        Integer maxSingleStreamForecasters,
+        Integer maxHCForecasters,
+        Integer maxForecastFeatures,
+        Integer maxCategoricalFields,
+        RestRequest.Method method,
+        NamedXContentRegistry xContentRegistry,
+        User user,
+        ForecastTaskManager taskManager,
+        SearchFeatureDao searchFeatureDao,
+        Settings settings
+    ) {
+        super(
+            clusterService,
+            client,
+            clientUtil,
+            transportService,
+            forecastIndices,
+            forecasterId,
+            seqNo,
+            primaryTerm,
+            refreshPolicy,
+            forecaster,
+            requestTimeout,
+            maxSingleStreamForecasters,
+            maxHCForecasters,
+            maxForecastFeatures,
+            maxCategoricalFields,
+            method,
+            xContentRegistry,
+            user,
+            taskManager,
+            searchFeatureDao,
+            null,
+            false,
+            null,
+            settings
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ValidateForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ValidateForecasterActionHandler.java
new file mode 100644
index 000000000..fe8549855
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/rest/handler/ValidateForecasterActionHandler.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.forecast.rest.handler;
+
+import java.time.Clock;
+
+import 
org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class ValidateForecasterActionHandler extends AbstractForecasterActionHandler { + + public ValidateForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ForecastIndexManagement forecastIndices, + Config forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxFeatures, + Integer maxCategoricalFields, + Method method, + NamedXContentRegistry xContentRegistry, + User user, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings + ) { + super( + clusterService, + client, + clientUtil, + null, + forecastIndices, + Config.NO_ID, + null, + null, + null, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxFeatures, + maxCategoricalFields, + method, + xContentRegistry, + user, + null, + searchFeatureDao, + validationType, + true, + clock, + settings + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java index 1db9bf340..b22ffc1cd 100644 --- a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java +++ b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java @@ -27,31 +27,12 @@ public class ForecastEnabledSetting extends DynamicNumericSetting { */ public static final String FORECAST_ENABLED = "plugins.forecast.enabled"; - public static final String FORECAST_BREAKER_ENABLED = "plugins.forecast.breaker.enabled"; - - public static final String FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.forecast.door_keeper_in_cache.enabled";; - public static final Map> settings = unmodifiableMap(new HashMap>() { { /** * forecast enable/disable setting */ put(FORECAST_ENABLED, Setting.boolSetting(FORECAST_ENABLED, true, NodeScope, Dynamic)); - - /** - * forecast breaker enable/disable setting - */ - put(FORECAST_BREAKER_ENABLED, Setting.boolSetting(FORECAST_BREAKER_ENABLED, true, NodeScope, Dynamic)); - - /** - * We have a bloom filter placed in front of inactive entity cache to - * filter out unpopular items that are not likely to appear more - * than once. Whether this bloom filter is enabled or not. - */ - put( - FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, - Setting.boolSetting(FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic) - ); } }); @@ -73,20 +54,4 @@ public static synchronized ForecastEnabledSetting getInstance() { public static boolean isForecastEnabled() { return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_ENABLED); } - - /** - * Whether forecast circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause an forecast job to be stopped. - * @return whether forecast circuit breaker is enabled or not. 
-     */
-    public static boolean isForecastBreakerEnabled() {
-        return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED);
-    }
-
-    /**
-     * If enabled, we filter out unpopular items that are not likely to appear more than once
-     * @return wWhether door keeper in cache is enabled or not.
-     */
-    public static boolean isDoorKeeperInCacheEnabled() {
-        return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED);
-    }
 }
diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java
index 6b4078ad4..dc58fe599 100644
--- a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java
+++ b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java
@@ -77,7 +77,7 @@ public final class ForecastSettings {
     public static final int MAX_FORECAST_FEATURES = 1;
 
     // ======================================
-    // AD Index setting
+    // Index setting
     // ======================================
     public static int FORECAST_MAX_UPDATE_RETRY_TIMES = 10_000;
@@ -386,4 +386,8 @@ public final class ForecastSettings {
     public static final Setting<Integer> FORECAST_MAX_MODEL_SIZE_PER_NODE = Setting
         .intSetting("plugins.forecast.max_model_size_per_node", 100, 1, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic);
 
+    // ======================================
+    // ML
+    // ======================================
+    public static final int MINIMUM_SHINLE_SIZE = 4;
 }
diff --git a/src/main/java/org/opensearch/forecast/stats/ForecastModelsOnNodeSupplier.java b/src/main/java/org/opensearch/forecast/stats/ForecastModelsOnNodeSupplier.java
new file mode 100644
index 000000000..fd53bbbdd
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/stats/ForecastModelsOnNodeSupplier.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.forecast.stats;
+
+import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE;
+import static org.opensearch.timeseries.ml.ModelState.LAST_CHECKPOINT_TIME_KEY;
+import static org.opensearch.timeseries.ml.ModelState.LAST_USED_TIME_KEY;
+import static org.opensearch.timeseries.ml.ModelState.MODEL_TYPE_KEY;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.forecast.caching.ForecastCacheProvider;
+import org.opensearch.forecast.constant.ForecastCommonName;
+import org.opensearch.timeseries.constant.CommonName;
+
+public class ForecastModelsOnNodeSupplier implements Supplier<List<Map<String, Object>>> {
+    private ForecastCacheProvider forecastCache;
+    private volatile int forecastNumModelsToReturn;
+
+    /**
+     * Set that contains the model stats that should be exposed.
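+     * <p>Each per-model map returned by {@code get()} is restricted to these keys,
+     * for example (a sketch; the key literals are assumed from the constants above):
+     * <pre>
+     * { "model_id": ..., "model_type": ..., "entity": ...,
+     *   "last_used_time": ..., "last_checkpoint_time": ..., "forecaster_id": ... }
+     * </pre>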
+     */
+    public static Set<String> MODEL_STATE_STAT_KEYS = new HashSet<>(
+        Arrays
+            .asList(
+                CommonName.MODEL_ID_FIELD,
+                MODEL_TYPE_KEY,
+                CommonName.ENTITY_KEY,
+                LAST_USED_TIME_KEY,
+                LAST_CHECKPOINT_TIME_KEY,
+                ForecastCommonName.FORECASTER_ID_KEY
+            )
+    );
+
+    /**
+     * Constructor
+     *
+     * @param forecastCache object that manages HC forecasters' models
+     * @param settings node settings accessor
+     * @param clusterService Cluster service accessor
+     */
+    public ForecastModelsOnNodeSupplier(ForecastCacheProvider forecastCache, Settings settings, ClusterService clusterService) {
+        this.forecastCache = forecastCache;
+        this.forecastNumModelsToReturn = FORECAST_MAX_MODEL_SIZE_PER_NODE.get(settings);
+        clusterService
+            .getClusterSettings()
+            .addSettingsUpdateConsumer(FORECAST_MAX_MODEL_SIZE_PER_NODE, it -> this.forecastNumModelsToReturn = it);
+    }
+
+    @Override
+    public List<Map<String, Object>> get() {
+        Stream<Map<String, Object>> forecastStream = forecastCache
+            .get()
+            .getAllModels()
+            .stream()
+            .limit(forecastNumModelsToReturn)
+            .map(
+                modelState -> modelState
+                    .getModelStateAsMap()
+                    .entrySet()
+                    .stream()
+                    .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey()))
+                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
+            );
+
+        return forecastStream.collect(Collectors.toList());
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/stats/ForecastStats.java b/src/main/java/org/opensearch/forecast/stats/ForecastStats.java
new file mode 100644
index 000000000..197043cde
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/stats/ForecastStats.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.forecast.stats;
+
+import java.util.Map;
+
+import org.opensearch.timeseries.stats.Stats;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+
+public class ForecastStats extends Stats {
+
+    public ForecastStats(Map<String, TimeSeriesStat<?>> stats) {
+        super(stats);
+    }
+
+}
diff --git a/src/main/java/org/opensearch/forecast/stats/suppliers/ForecastModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/forecast/stats/suppliers/ForecastModelsOnNodeCountSupplier.java
new file mode 100644
index 000000000..7a8b5283d
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/stats/suppliers/ForecastModelsOnNodeCountSupplier.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.forecast.stats.suppliers;
+
+import java.util.function.Supplier;
+
+import org.opensearch.forecast.caching.ForecastCacheProvider;
+
+/**
+ * ModelsOnNodeCountSupplier provides the number of models a node contains
+ */
+public class ForecastModelsOnNodeCountSupplier implements Supplier<Long> {
+    private ForecastCacheProvider forecastCache;
+
+    /**
+     * Constructor
+     *
+     * @param forecastCache object that manages models
+     */
+    public ForecastModelsOnNodeCountSupplier(ForecastCacheProvider forecastCache) {
+        this.forecastCache = forecastCache;
+    }
+
+    @Override
+    public Long get() {
+        return forecastCache.get().getAllModels().stream().count();
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java b/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java
new file mode 100644
index 000000000..bc2c63002
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java
@@ -0,0 +1,521 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0
license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.task; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FORECASTER_IS_RUNNING; +import static org.opensearch.forecast.indices.ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN; +import static org.opensearch.forecast.model.ForecastTask.FORECASTER_ID_FIELD; +import static org.opensearch.forecast.model.ForecastTaskType.REALTIME_TASK_TYPES; +import static org.opensearch.forecast.settings.ForecastSettings.DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_OLD_TASK_DOCS_PER_FORECASTER; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME; +import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_ID_FIELD; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.function.ResponseTransformer; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import 
org.opensearch.timeseries.task.TaskManager;
+import org.opensearch.timeseries.transport.JobResponse;
+import org.opensearch.timeseries.util.ExceptionUtil;
+import org.opensearch.timeseries.util.ParseUtils;
+import org.opensearch.transport.TransportService;
+
+public class ForecastTaskManager extends
+    TaskManager {
+    private final Logger logger = LogManager.getLogger(ForecastTaskManager.class);
+
+    public ForecastTaskManager(
+        TaskCacheManager forecastTaskCacheManager,
+        Client client,
+        NamedXContentRegistry xContentRegistry,
+        ForecastIndexManagement forecastIndices,
+        ClusterService clusterService,
+        Settings settings,
+        ThreadPool threadPool,
+        NodeStateManager nodeStateManager
+    ) {
+        super(
+            forecastTaskCacheManager,
+            clusterService,
+            client,
+            ForecastIndex.STATE.getIndexName(),
+            ForecastTaskType.REALTIME_TASK_TYPES,
+            Collections.emptyList(),
+            ForecastTaskType.RUN_ONCE_TASK_TYPES,
+            forecastIndices,
+            nodeStateManager,
+            AnalysisType.FORECAST,
+            xContentRegistry,
+            FORECASTER_ID_FIELD,
+            MAX_OLD_TASK_DOCS_PER_FORECASTER,
+            settings,
+            threadPool,
+            ALL_FORECAST_RESULTS_INDEX_PATTERN,
+            FORECAST_THREAD_POOL_NAME,
+            DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER,
+            TaskState.INACTIVE
+        );
+    }
+
+    /**
+     * Initialize the realtime task cache. Realtime forecasting relies on the job scheduler to
+     * choose a node (the job coordinating node) to run the forecast job. Nodes holding a primary
+     * or replica shard of the job index are candidates to run it. The job scheduler builds a hash
+     * ring over these candidate nodes and picks one. If a job index shard relocates (for example,
+     * when a new node joins the cluster), the job scheduler rebuilds the hash ring and may pick a
+     * different node, so we need to initialize the realtime task cache on the new coordinating node.
+     *
+     * If the realtime task cache is initialized for the first time on this node, the listener
+     * receives true; otherwise it receives false.
+     *
+     * We don't clean up the realtime task cache on the old coordinating node, as HourlyCron will
+     * clear the cache there.
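+     *
+     * <p>Caller sketch (editorial; the listener bodies are illustrative):
+     * <pre>{@code
+     * taskManager.initRealtimeTaskCacheAndCleanupStaleCache(forecasterId, forecaster,
+     *     transportService, ActionListener.wrap(
+     *         firstInit -> logger.debug("Realtime cache newly initialized: {}", firstInit),
+     *         e -> logger.error("Failed to init realtime cache", e)));
+     * }</pre>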
+     *
+     * @param forecasterId forecaster id
+     * @param forecaster forecaster
+     * @param transportService transport service
+     * @param listener listener
+     */
+    @Override
+    public void initRealtimeTaskCacheAndCleanupStaleCache(
+        String forecasterId,
+        Config forecaster,
+        TransportService transportService,
+        ActionListener<Boolean> listener
+    ) {
+        try {
+            if (taskCacheManager.getRealtimeTaskCache(forecasterId) != null) {
+                listener.onResponse(false);
+                return;
+            }
+
+            getAndExecuteOnLatestConfigLevelTask(forecasterId, REALTIME_TASK_TYPES, (forecastTaskOptional) -> {
+                if (forecastTaskOptional.isEmpty()) {
+                    logger.debug("Can't find realtime task for forecaster {}, init realtime task cache directly", forecasterId);
+                    ExecutorFunction function = () -> createNewTask(
+                        forecaster,
+                        null,
+                        false,
+                        forecaster.getUser(),
+                        clusterService.localNode().getId(),
+                        TaskState.CREATED,
+                        ActionListener.wrap(r -> {
+                            logger.info("Recreate realtime task successfully for forecaster {}", forecasterId);
+                            taskCacheManager.initRealtimeTaskCache(forecasterId, forecaster.getIntervalInMilliseconds());
+                            listener.onResponse(true);
+                        }, e -> {
+                            logger.error("Failed to recreate realtime task for forecaster " + forecasterId, e);
+                            listener.onFailure(e);
+                        })
+                    );
+                    recreateRealtimeTaskBeforeExecuting(function, listener);
+                    return;
+                }
+
+                logger.info("Init realtime task cache for forecaster {}", forecasterId);
+                taskCacheManager.initRealtimeTaskCache(forecasterId, forecaster.getIntervalInMilliseconds());
+                listener.onResponse(true);
+            }, transportService, false, listener);
+        } catch (Exception e) {
+            logger.error("Failed to init realtime task cache for " + forecasterId, e);
+            listener.onFailure(e);
+        }
+    }
+
+    /**
+     * Update forecast task with specific fields.
+     *
+     * @param taskId forecast task id
+     * @param updatedFields updated fields, key: field name, value: new value
+     */
+    public void updateForecastTask(String taskId, Map<String, Object> updatedFields) {
+        updateForecastTask(taskId, updatedFields, ActionListener.wrap(response -> {
+            if (response.status() == RestStatus.OK) {
+                logger.debug("Updated forecast task successfully: {}, task id: {}", response.status(), taskId);
+            } else {
+                logger.error("Failed to update forecast task {}, status: {}", taskId, response.status());
+            }
+        }, e -> { logger.error("Failed to update task: " + taskId, e); }));
+    }
+
+    /**
+     * Update forecast task for specific fields.
+     *
+     * @param taskId task id
+     * @param updatedFields updated fields, key: field name, value: new value
+     * @param listener action listener
+     */
+    public void updateForecastTask(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
+        UpdateRequest updateRequest = new UpdateRequest(ForecastIndex.STATE.getIndexName(), taskId);
+        Map<String, Object> updatedContent = new HashMap<>();
+        updatedContent.putAll(updatedFields);
+        updatedContent.put(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
+        updateRequest.doc(updatedContent);
+        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        client.update(updateRequest, listener);
+    }
+
+    private void recreateRealtimeTaskBeforeExecuting(ExecutorFunction function, ActionListener<Boolean> listener) {
+        if (indexManagement.doesStateIndexExist()) {
+            function.execute();
+        } else {
+            // If the forecast state index doesn't exist, create the index and then execute the function.
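+            // Note: another node may create the state index concurrently; the failure
+            // handler below treats ResourceAlreadyExistsException as success and still
+            // executes the function.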
+            indexManagement.initStateIndex(ActionListener.wrap(r -> {
+                if (r.isAcknowledged()) {
+                    logger.info("Created {} with mappings.", ForecastIndex.STATE.getIndexName());
+                    function.execute();
+                } else {
+                    String error = String
+                        .format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, ForecastIndex.STATE.getIndexName());
+                    logger.warn(error);
+                    listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR));
+                }
+            }, e -> {
+                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
+                    function.execute();
+                } else {
+                    logger.error("Failed to init forecast state index", e);
+                    listener.onFailure(e);
+                }
+            }));
+        }
+    }
+
+    /**
+     * Poll a deleted forecaster task from the cache and delete its child tasks and forecast results.
+     */
+    @Override
+    public void cleanChildTasksAndResultsOfDeletedTask() {
+        if (!taskCacheManager.hasDeletedTask()) {
+            return;
+        }
+        threadPool.schedule(() -> {
+            String taskId = taskCacheManager.pollDeletedTask();
+            if (taskId == null) {
+                return;
+            }
+            DeleteByQueryRequest deleteForecastResultsRequest = new DeleteByQueryRequest(ALL_FORECAST_RESULTS_INDEX_PATTERN);
+            deleteForecastResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId));
+            client.execute(DeleteByQueryAction.INSTANCE, deleteForecastResultsRequest, ActionListener.wrap(res -> {
+                logger.debug("Successfully deleted forecast results of task " + taskId);
+                DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(ForecastIndex.STATE.getIndexName());
+                deleteChildTasksRequest.setQuery(new TermsQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, taskId));
+
+                client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> {
+                    logger.debug("Successfully deleted child tasks of task " + taskId);
+                    cleanChildTasksAndResultsOfDeletedTask();
+                }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); }));
+            }, ex -> { logger.error("Failed to delete forecast results for task " + taskId, ex); }));
+        }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME);
+    }
+
+    @Override
+    public void startHistorical(
+        Config config,
+        DateRange dateRange,
+        User user,
+        TransportService transportService,
+        ActionListener listener
+    ) {
+        // TODO Auto-generated method stub
+
+    }
+
+    @Override
+    protected TaskType getTaskType(Config config, DateRange dateRange, boolean runOnce) {
+        if (runOnce) {
+            return config.isHighCardinality()
+                ? ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER
+                : ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM;
+        } else {
+            return config.isHighCardinality()
+                ? ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER
+                : ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM;
+        }
+    }
+
+    @Override
+    protected <T> void createNewTask(
+        Config config,
+        DateRange dateRange,
+        boolean runOnce,
+        User user,
+        String coordinatingNode,
+        TaskState initialState,
+        ActionListener<T> listener
+    ) {
+        String userName = user == null ? null : user.getName();
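+        // Note: a run-once (test) task is built without progress fields or a date range,
+        // while a realtime task starts with zeroed progress and the given date range;
+        // see the two branches below.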
+        Instant now = Instant.now();
+        String taskType = getTaskType(config, dateRange, runOnce).name();
+        ForecastTask.Builder forecastTaskBuilder = new ForecastTask.Builder()
+            .configId(config.getId())
+            .forecaster((Forecaster) config)
+            .isLatest(true)
+            .taskType(taskType)
+            .executionStartTime(now)
+            .state(initialState.name())
+            .lastUpdateTime(now)
+            .startedBy(userName)
+            .coordinatingNode(coordinatingNode)
+            .user(user);
+
+        ResponseTransformer<IndexResponse, T> responseTransformer;
+
+        final ForecastTask forecastTask;
+
+        // used for run once
+        if (initialState == TaskState.INIT_TEST) {
+            forecastTask = forecastTaskBuilder.build();
+            responseTransformer = (indexResponse) -> (T) forecastTask;
+        } else {
+            forecastTask = forecastTaskBuilder.taskProgress(0.0f).initProgress(0.0f).dateRange(dateRange).build();
+            // used for real time
+            responseTransformer = (indexResponse) -> (T) new JobResponse(indexResponse.getId());
+        }
+
+        createTaskDirectly(
+            forecastTask,
+            r -> onIndexConfigTaskResponse(
+                r,
+                forecastTask,
+                (response, delegatedListener) -> cleanOldConfigTaskDocs(response, forecastTask, responseTransformer, delegatedListener),
+                listener
+            ),
+            listener
+        );
+
+    }
+
+    @Override
+    public void cleanConfigCache(
+        TimeSeriesTask task,
+        TransportService transportService,
+        ExecutorFunction function,
+        ActionListener listener
+    ) {
+        // no op for forecaster as we rely on state ttl to auto clean it
+        // only execute function
+        function.execute();
+    }
+
+    @Override
+    protected boolean isHistoricalHCTask(TimeSeriesTask task) {
+        // we have no backtesting
+        return false;
+    }
+
+    @Override
+    protected <T> void onIndexConfigTaskResponse(
+        IndexResponse response,
+        ForecastTask forecastTask,
+        BiConsumer<IndexResponse, ActionListener<T>> function,
+        ActionListener<T> listener
+    ) {
+        if (response == null || response.getResult() != CREATED) {
+            String errorMsg = ExceptionUtil.getShardsFailure(response);
+            listener.onFailure(new OpenSearchStatusException(errorMsg, response.status()));
+            return;
+        }
+        forecastTask.setTaskId(response.getId());
+        ActionListener<T> delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> {
+            handleTaskException(forecastTask, e);
+            if (e instanceof DuplicateTaskException) {
+                listener.onFailure(new OpenSearchStatusException(FORECASTER_IS_RUNNING, RestStatus.BAD_REQUEST));
+            } else {
+                // TODO: For a historical forecast task, what should we do if any exception happens?
+                // For realtime forecasting, the task cache is initialized when the realtime job starts; see
+                // ForecastTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. The realtime task
+                // cache is not initialized yet when the forecast task is created, so there is nothing to clean up here.
+                listener.onFailure(e);
+            }
+        });
+        // TODO: what to do if this is a historical task?
+        if (function != null) {
+            function.accept(response, delegatedListener);
+        }
+    }
+
+    @Override
+    protected <T> void runBatchResultAction(
+        IndexResponse response,
+        ForecastTask tsTask,
+        ResponseTransformer<IndexResponse, T> responseTransformer,
+        ActionListener<T> listener
+    ) {
+        // TODO Auto-generated method stub
+        throw new UnsupportedOperationException("Forecast does not support back testing yet.");
+    }
+
+    @Override
+    protected BiCheckedFunction<XContentParser, String, ForecastTask, IOException> getTaskParser() {
+        return ForecastTask::parse;
+    }
+
+    @Override
+    public void createRunOnceTaskAndCleanupStaleTasks(
+        String configId,
+        Config config,
+        TransportService transportService,
+        ActionListener listener
+    ) {
+        ForecastTaskType taskType = config.isHighCardinality()
+            ? ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER
+            : ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM;
+
+        try {
+
+            if (indexManagement.doesStateIndexExist()) {
+                // If the state index exists, check whether the latest task is still running
+                getAndExecuteOnLatestConfigLevelTask(config.getId(), Arrays.asList(taskType), (task) -> {
+                    if (!task.isPresent() || task.get().isDone()) {
+                        updateLatestFlagOfOldTasksAndCreateNewTask(config, null, true, config.getUser(), TaskState.INIT_TEST, listener);
+                    } else {
+                        listener.onFailure(new OpenSearchStatusException("run once is on-going", RestStatus.BAD_REQUEST));
+                    }
+                }, transportService, true, listener);
+            } else {
+                // If the state index doesn't exist, create the index and then execute the forecast.
+                indexManagement.initStateIndex(ActionListener.wrap(r -> {
+                    if (r.isAcknowledged()) {
+                        logger.info("Created {} with mappings.", stateIndex);
+                        updateLatestFlagOfOldTasksAndCreateNewTask(config, null, true, config.getUser(), TaskState.INIT_TEST, listener);
+                    } else {
+                        String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, stateIndex);
+                        logger.warn(error);
+                        listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR));
+                    }
+                }, e -> {
+                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
+                        updateLatestFlagOfOldTasksAndCreateNewTask(config, null, true, config.getUser(), TaskState.INIT_TEST, listener);
+                    } else {
+                        logger.error("Failed to init forecast state index", e);
+                        listener.onFailure(e);
+                    }
+                }));
+            }
+        } catch (Exception e) {
+            logger.error("Failed to start forecaster " + config.getId(), e);
+            listener.onFailure(e);
+        }
+    }
+
+    @Override
+    public List<ForecastTaskType> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce) {
+        if (runOnce) {
+            return ForecastTaskType.RUN_ONCE_TASK_TYPES;
+        } else {
+            return ForecastTaskType.REALTIME_TASK_TYPES;
+        }
+    }
+
+    private <T> void resetRunOnceConfigTaskState(
+        List<TimeSeriesTask> runOnceTasks,
+        ExecutorFunction function,
+        TransportService transportService,
+        ActionListener<T> listener
+    ) {
+        if (ParseUtils.isNullOrEmpty(runOnceTasks)) {
+            function.execute();
+            return;
+        }
+        ForecastTask forecastTask = (ForecastTask) runOnceTasks.get(0);
+        resetTaskStateAsStopped(forecastTask, function, transportService, listener);
+    }
+
+    /**
+     * Reset the latest config task state. Will reset both historical and realtime tasks.
+     * [Important!] Make sure the listener is notified inside the function.
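+     * (Otherwise the action listener is never completed and the caller waits until it
+     * times out; this consequence is inferred from the surrounding code, not stated by it.)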
+     *
+     * @param tasks tasks
+     * @param function consumer function
+     * @param transportService transport service
+     * @param listener action listener
+     * @param <T> response type of action listener
+     */
+    @Override
+    protected <T> void resetLatestConfigTaskState(
+        List<TimeSeriesTask> tasks,
+        Consumer<List<TimeSeriesTask>> function,
+        TransportService transportService,
+        ActionListener<T> listener
+    ) {
+        List<TimeSeriesTask> runningRealtimeTasks = new ArrayList<>();
+        List<TimeSeriesTask> runningRunOnceTasks = new ArrayList<>();
+
+        for (TimeSeriesTask task : tasks) {
+            if (!task.isHistoricalEntityTask() && !task.isDone()) {
+                if (task.isRealTimeTask()) {
+                    runningRealtimeTasks.add(task);
+                } else if (task.isRunOnceTask()) {
+                    runningRunOnceTasks.add(task);
+                }
+            }
+        }
+
+        resetRunOnceConfigTaskState(
+            runningRunOnceTasks,
+            () -> resetRealtimeConfigTaskState(runningRealtimeTasks, () -> function.accept(tasks), transportService, listener),
+            transportService,
+            listener
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/BuildInQuery.java b/src/main/java/org/opensearch/forecast/transport/BuildInQuery.java
new file mode 100644
index 000000000..c36c930c6
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/BuildInQuery.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.forecast.transport;
+
+public enum BuildInQuery {
+    MIN_CONFIDENCE_INTERVAL_WIDTH,
+    MAX_CONFIDENCE_INTERVAL_WIDTH,
+    MIN_VALUE_WITHIN_THE_HORIZON,
+    MAX_VALUE_WITHIN_THE_HORIZON,
+    DISTANCE_TO_THRESHOLD_VALUE
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java
new file mode 100644
index 000000000..eab816842
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java
@@ -0,0 +1,27 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.transport;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.forecast.constant.ForecastCommonValue;
+import org.opensearch.timeseries.transport.DeleteModelResponse;
+
+public class DeleteForecastModelAction extends ActionType<DeleteModelResponse> {
+    // Internal Action which is not used for public facing RestAPIs.
+    public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "model/delete";
+    public static final DeleteForecastModelAction INSTANCE = new DeleteForecastModelAction();
+
+    private DeleteForecastModelAction() {
+        super(NAME, DeleteModelResponse::new);
+    }
+
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java
new file mode 100644
index 000000000..fad3bdd12
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java
@@ -0,0 +1,58 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.transport;
+
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.forecast.caching.ForecastCacheProvider;
+import org.opensearch.forecast.caching.ForecastPriorityCache;
+import org.opensearch.forecast.indices.ForecastIndex;
+import org.opensearch.forecast.indices.ForecastIndexManagement;
+import org.opensearch.forecast.ml.ForecastCheckpointDao;
+import org.opensearch.forecast.ml.ForecastColdStart;
+import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.NodeStateManager;
+import org.opensearch.timeseries.task.TaskCacheManager;
+import org.opensearch.timeseries.transport.BaseDeleteModelTransportAction;
+import org.opensearch.transport.TransportService;
+
+import com.amazon.randomcutforest.parkservices.RCFCaster;
+
+public class DeleteForecastModelTransportAction extends BaseDeleteModelTransportAction {
+
+    @Inject
+    public DeleteForecastModelTransportAction(
+        ThreadPool threadPool,
+        ClusterService clusterService,
+        TransportService transportService,
+        ActionFilters actionFilters,
+        NodeStateManager nodeStateManager,
+        ForecastCacheProvider cache,
+        TaskCacheManager taskCacheManager,
+        ForecastColdStart coldStarter
+    ) {
+        super(
+            threadPool,
+            clusterService,
+            transportService,
+            actionFilters,
+            nodeStateManager,
+            cache,
+            taskCacheManager,
+            coldStarter,
+            DeleteForecastModelAction.NAME
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterAction.java
new file mode 100644
index 000000000..c18bc2327
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterAction.java
@@ -0,0 +1,27 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.transport;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.action.delete.DeleteResponse;
+import org.opensearch.forecast.constant.ForecastCommonValue;
+
+public class DeleteForecasterAction extends ActionType<DeleteResponse> {
+    // External Action which is used for public facing RestAPIs.
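+    // Usage sketch (illustrative, not part of this change): a transport-layer caller
+    // would run this action roughly as
+    //   client.execute(DeleteForecasterAction.INSTANCE, deleteRequest, listener);
+    // where deleteRequest is whatever request type DeleteForecasterTransportAction
+    // is registered to handle.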
+ public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/delete"; + public static final DeleteForecasterAction INSTANCE = new DeleteForecasterAction(); + + private DeleteForecasterAction() { + super(NAME, DeleteResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterTransportAction.java new file mode 100644 index 000000000..bf6094934 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecasterTransportAction.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseDeleteConfigTransportAction; +import org.opensearch.transport.TransportService; + +public class DeleteForecasterTransportAction extends + BaseDeleteConfigTransportAction { + + @Inject + public DeleteForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, + ForecastTaskManager taskManager + ) { + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + nodeStateManager, + taskManager, + DeleteForecasterAction.NAME, + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + AnalysisType.FORECAST, + ForecastIndex.STATE.getIndexName(), + Forecaster.class, + ForecastTaskType.RUN_ONCE_TASK_TYPES + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java new file mode 100644 index 000000000..77eec3d51 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class EntityForecastResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. 
+    public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "entity/result";
+    public static final EntityForecastResultAction INSTANCE = new EntityForecastResultAction();
+
+    private EntityForecastResultAction() {
+        super(NAME, AcknowledgedResponse::new);
+    }
+
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java
new file mode 100644
index 000000000..d638b3bae
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java
@@ -0,0 +1,174 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+package org.opensearch.forecast.transport;
+
+import java.util.Optional;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.action.support.master.AcknowledgedResponse;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.forecast.caching.ForecastCacheProvider;
+import org.opensearch.forecast.caching.ForecastPriorityCache;
+import org.opensearch.forecast.indices.ForecastIndex;
+import org.opensearch.forecast.indices.ForecastIndexManagement;
+import org.opensearch.forecast.ml.ForecastCheckpointDao;
+import org.opensearch.forecast.ml.ForecastColdStart;
+import org.opensearch.forecast.ml.ForecastModelManager;
+import org.opensearch.forecast.ml.RCFCasterResult;
+import org.opensearch.forecast.model.ForecastResult;
+import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker;
+import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker;
+import org.opensearch.forecast.ratelimit.ForecastColdEntityWorker;
+import org.opensearch.forecast.ratelimit.ForecastColdStartWorker;
+import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker;
+import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy;
+import org.opensearch.forecast.stats.ForecastStats;
+import org.opensearch.tasks.Task;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.NodeStateManager;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.caching.CacheProvider;
+import org.opensearch.timeseries.common.exception.EndRunException;
+import org.opensearch.timeseries.common.exception.LimitExceededException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.transport.EntityResultProcessor;
+import org.opensearch.timeseries.transport.EntityResultRequest;
+import org.opensearch.timeseries.util.ExceptionUtil;
+import org.opensearch.transport.TransportService;
+
+import com.amazon.randomcutforest.parkservices.RCFCaster;
+
+/**
+ * Entry-point for HC forecast workflow. We have created multiple queues for coordinating
+ * the workflow. The overall workflow is:
+ * 1. We store as many frequently used entity models in a cache as allowed by the
+ * memory limit (by default 10% heap).
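+ * (Illustratively, under the default 10% setting a 1 GB heap would give the cache at most
+ * roughly 100 MB for models; the exact bound comes from the cache's memory tracking, so
+ * this figure is only an approximation.)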
If an entity feature is a hit, we use the in-memory model + * to forecast and record results using the result write queue. + * 2. If an entity feature is a miss, we check if there is free memory or any other + * entity's model can be evacuated. An in-memory entity's frequency may be lower + * compared to the cache miss entity. If that's the case, we replace the lower + * frequency entity's model with the higher frequency entity's model. To load the + * higher frequency entity's model, we first check if a model exists on disk by + * sending a checkpoint read queue request. If there is a checkpoint, we load it + * to memory, perform forecast, and save the result using the result write queue. + * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the checkpoint + * write queue. + * 3. We also have the cold entity queue configured for cold entities, and the model + * training and inference are connected by serial juxtaposition to limit resource usage. + */ +public class EntityForecastResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityForecastResultTransportAction.class); + private CircuitBreakerService circuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ThreadPool threadPool; + private EntityResultProcessor intervalDataProcessor; + + private final ForecastCacheProvider entityCache; + private final ForecastModelManager manager; + private final ForecastStats timeSeriesStats; + private final ForecastColdStartWorker entityColdStartWorker; + private final ForecastCheckpointReadWorker checkpointReadQueue; + private final ForecastColdEntityWorker coldEntityQueue; + private final ForecastSaveResultStrategy forecastSaveResultStategy; + + @Inject + public EntityForecastResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ForecastModelManager manager, + CircuitBreakerService adCircuitBreakerService, + ForecastCacheProvider entityCache, + NodeStateManager stateManager, + ForecastIndexManagement indexUtil, + ForecastResultWriteWorker resultWriteQueue, + ForecastCheckpointReadWorker checkpointReadQueue, + ForecastColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + ForecastColdStartWorker entityColdStartWorker, + ForecastStats timeSeriesStats, + ForecastSaveResultStrategy forecastSaveResultStategy + ) { + super(EntityForecastResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.circuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.threadPool = threadPool; + this.intervalDataProcessor = null; + this.entityCache = entityCache; + this.manager = manager; + this.timeSeriesStats = timeSeriesStats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.forecastSaveResultStategy = forecastSaveResultStategy; + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (circuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String 
forecasterId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(forecasterId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", forecasterId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, forecasterId); + } + + intervalDataProcessor = new EntityResultProcessor<>( + entityCache, + manager, + timeSeriesStats, + entityColdStartWorker, + checkpointReadQueue, + coldEntityQueue, + forecastSaveResultStategy, + StatNames.FORECAST_MODEL_CORRUTPION_COUNT + ); + + stateManager + .getConfig( + forecasterId, + request.getAnalysisType(), + intervalDataProcessor.onGetConfig(listener, forecasterId, request, previousException, request.getAnalysisType()) + ); + } catch (Exception exception) { + LOG.error("fail to get entity's forecasts", exception); + listener.onFailure(exception); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileAction.java new file mode 100644 index 000000000..de4b48ef4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.EntityProfileResponse; + +public class ForecastEntityProfileAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "forecasters/profile/entity"; + public static final ForecastEntityProfileAction INSTANCE = new ForecastEntityProfileAction(); + + private ForecastEntityProfileAction() { + super(NAME, EntityProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileTransportAction.java new file mode 100644 index 000000000..6fe726c4e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastEntityProfileTransportAction.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.transport.BaseEntityProfileTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * Transport action to get entity profile. + */ +public class ForecastEntityProfileTransportAction extends + BaseEntityProfileTransportAction { + + @Inject + public ForecastEntityProfileTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + HashRing hashRing, + ClusterService clusterService, + ForecastCacheProvider cacheProvider + ) { + super( + actionFilters, + transportService, + settings, + hashRing, + clusterService, + cacheProvider, + ForecastEntityProfileAction.NAME, + ForecastSettings.FORECAST_REQUEST_TIMEOUT + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastProfileAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastProfileAction.java new file mode 100644 index 000000000..35595a76f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastProfileAction.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ProfileResponse; + +/** + * Profile transport action + */ +public class ForecastProfileAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/profile"; + public static final ForecastProfileAction INSTANCE = new ForecastProfileAction(); + + /** + * Constructor + */ + private ForecastProfileAction() { + super(NAME, ProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastProfileTransportAction.java new file mode 100644 index 000000000..87c7ccdba --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastProfileTransportAction.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.BaseProfileTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * This class contains the logic to extract the stats from the nodes + */ +public class ForecastProfileTransportAction extends BaseProfileTransportAction { + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param cacheProvider cache provider + * @param settings Node settings accessor + */ + @Inject + public ForecastProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ForecastCacheProvider cacheProvider, + Settings settings + ) { + super( + ForecastProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + cacheProvider, + settings, + FORECAST_MAX_MODEL_SIZE_PER_NODE + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java new file mode 100644 index 000000000..ef9178a02 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastResultAction extends ActionType { + // External Action which used for public facing RestAPIs or actions we need to assume cx's role. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/run"; + public static final ForecastResultAction INSTANCE = new ForecastResultAction(); + + private ForecastResultAction() { + super(NAME, ForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java new file mode 100644 index 000000000..6394636b3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.transport.TransportRequestOptions; + +public class ForecastResultBulkAction extends ActionType { + + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final ForecastResultBulkAction INSTANCE = new ForecastResultBulkAction(); + + private ForecastResultBulkAction() { + super(NAME, ResultBulkResponse::new); + } + + @Override + public TransportRequestOptions transportOptions(Settings settings) { + return TransportRequestOptions.builder().withType(TransportRequestOptions.Type.BULK).build(); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java new file mode 100644 index 000000000..730275b4d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.timeseries.transport.ResultBulkRequest; + +public class ForecastResultBulkRequest extends ResultBulkRequest { + + public ForecastResultBulkRequest() { + super(); + } + + public ForecastResultBulkRequest(StreamInput in) throws IOException { + super(in, ForecastResultWriteRequest::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java new file mode 100644 index 000000000..95422a98a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT; + +import java.util.List; + +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.index.IndexingPressure; +import org.opensearch.timeseries.transport.ResultBulkTransportAction; +import org.opensearch.transport.TransportService; + +public class ForecastResultBulkTransportAction extends + ResultBulkTransportAction { + + @Inject + public ForecastResultBulkTransportAction( + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + ClusterService clusterService, + Client client + ) { + super( + ForecastResultBulkAction.NAME, + transportService, + actionFilters, + indexingPressure, + settings, + client, + FORECAST_INDEX_PRESSURE_SOFT_LIMIT.get(settings), + FORECAST_INDEX_PRESSURE_HARD_LIMIT.get(settings), + ForecastIndex.RESULT.getIndexName(), + ForecastResultBulkRequest::new + ); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); + } + + @Override + protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { + BulkRequest bulkRequest = new BulkRequest(); + List results = request.getAnomalyResults(); + + if (indexingPressurePercent <= softLimit) { + for (ForecastResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getResultIndex()); + } + } else if (indexingPressurePercent <= hardLimit) { + // exceed soft limit (60%) but smaller than hard limit (90%) + float acceptProbability = 1 - indexingPressurePercent; + for (ForecastResultWriteRequest resultWriteRequest : results) { + ForecastResult result = resultWriteRequest.getResult(); + if (random.nextFloat() < acceptProbability) { + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); + } + } + } else { + // if exceeding hard limit, only index error result + for (ForecastResultWriteRequest resultWriteRequest : results) { + ForecastResult result = resultWriteRequest.getResult(); + if (result.isHighPriority()) { + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); + } + } + } + + return bulkRequest; + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java new file mode 100644 index 000000000..36521c53d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java @@ -0,0 +1,198 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static 
org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_PAGE_SIZE; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.transport.SingleStreamResultRequest; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastResultProcessor extends + ResultProcessor { + + private static final Logger LOG = LogManager.getLogger(ForecastResultProcessor.class); + + public ForecastResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + ForecastStats timeSeriesStats, + ForecastTaskManager realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager, + AnalysisType analysisType, + boolean runOnce + ) { + super( + requestTimeoutSetting, + intervalRatioForRequests, + entityResultAction, + hcRequestCountStat, + settings, + clusterService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + hashRing, + nodeStateManager, + transportService, + timeSeriesStats, + realTimeTaskManager, + 
xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + transportResultResponseClazz, + featureManager, + FORECAST_MAX_ENTITIES_PER_INTERVAL, + FORECAST_PAGE_SIZE, + analysisType, + runOnce + ); + } + + @Override + protected ActionListener onFeatureResponseForSingleStreamConfig( + String forecasterId, + Config config, + ActionListener listener, + String rcfModelId, + DiscoveryNode rcfNode, + long dataStartTime, + long dataEndTime, + String taskId + ) { + return ActionListener.wrap(featureOptional -> { + Optional previousException = nodeStateManager.fetchExceptionAndClear(forecasterId); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous forecast exception of [{}]", forecasterId), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + Forecaster forecaster = (Forecaster) config; + + if (featureOptional.getUnprocessedFeatures().isEmpty()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. + LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, forecasterId); + listener + .onResponse(createResultResponse(new ArrayList(), "No data in current window", null, null, false, taskId)); + return; + } + + final AtomicReference failure = new AtomicReference(); + + LOG.info("Sending forecast single stream request to {} for model {}", rcfNode.getId(), rcfModelId); + + transportService + .sendRequest( + rcfNode, + ForecastSingleStreamResultAction.NAME, + new SingleStreamResultRequest( + forecasterId, + rcfModelId, + dataStartTime, + dataEndTime, + featureOptional.getUnprocessedFeatures().get(), + taskId + ), + option, + new ActionListenerResponseHandler<>( + new ErrorResponseListener(rcfNode.getId(), forecasterId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else if (!featureOptional.getUnprocessedFeatures().isPresent()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. 
+ LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, forecasterId); + listener + .onResponse(createResultResponse(new ArrayList(), "No data in current window", null, null, false, taskId)); + } else { + listener + .onResponse( + createResultResponse(new ArrayList(), null, null, forecaster.getIntervalInMinutes(), true, taskId) + ); + } + }, exception -> { handleQueryFailure(exception, listener, forecasterId); }); + } + + @Override + protected ForecastResultResponse createResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long configInterval, + Boolean isHC, + String taskId + ) { + return new ForecastResultResponse(features, error, rcfTotalUpdates, configInterval, isHC, taskId); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java new file mode 100644 index 000000000..074e975e4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ForecastResultRequest extends ResultRequest { + + public ForecastResultRequest(StreamInput in) throws IOException { + super(in); + in.readEnum(AnalysisType.class); + } + + public ForecastResultRequest(String forecastID, long start, long end) { + super(forecastID, start, end); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(configId)) { + validationException = addValidationError(ForecastCommonMessages.FORECASTER_ID_MISSING_MSG, validationException); + } + // at least end time should be set + if (end <= 0) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", ForecastCommonMessages.INVALID_TIMESTAMP_ERR_MSG, start, end), + validationException + ); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ForecastCommonName.ID_JSON_KEY, configId); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java 
b/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java new file mode 100644 index 000000000..b1c7a8b47 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java @@ -0,0 +1,221 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.transport.ResultResponse; + +public class ForecastResultResponse extends ResultResponse { + public static final String DATA_QUALITY_JSON_KEY = "dataQuality"; + public static final String ERROR_JSON_KEY = "error"; + public static final String FEATURES_JSON_KEY = "features"; + public static final String FEATURE_VALUE_JSON_KEY = "value"; + public static final String RCF_TOTAL_UPDATES_JSON_KEY = "rcfTotalUpdates"; + public static final String FORECASTER_INTERVAL_IN_MINUTES_JSON_KEY = "forecasterIntervalInMinutes"; + public static final String FORECAST_VALUES_JSON_KEY = "forecastValues"; + public static final String FORECAST_UPPERS_JSON_KEY = "forecastUppers"; + public static final String FORECAST_LOWERS_JSON_KEY = "forecastLowers"; + + private Double dataQuality; + private float[] forecastsValues; + private float[] forecastsUppers; + private float[] forecastsLowers; + + // used when returning an error/exception or empty result + public ForecastResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long forecasterIntervalInMinutes, + Boolean isHCForecaster, + String taskId + ) { + this(Double.NaN, features, error, rcfTotalUpdates, forecasterIntervalInMinutes, isHCForecaster, null, null, null, taskId); + } + + public ForecastResultResponse( + Double confidence, + List features, + String error, + Long rcfTotalUpdates, + Long forecasterIntervalInMinutes, + Boolean isHCForecaster, + float[] forecastsValues, + float[] forecastsUppers, + float[] forecastsLowers, + String taskId + ) { + super(features, error, rcfTotalUpdates, forecasterIntervalInMinutes, isHCForecaster, taskId); + this.dataQuality = confidence; + this.forecastsValues = forecastsValues; + this.forecastsUppers = forecastsUppers; + this.forecastsLowers = forecastsLowers; + this.taskId = taskId; + } + + public ForecastResultResponse(StreamInput in) throws IOException { + super(in); + dataQuality = in.readDouble(); + int size = in.readVInt(); + features = new ArrayList(); + for (int i = 0; i < size; i++) { + features.add(new FeatureData(in)); + } + error = in.readOptionalString(); + rcfTotalUpdates = in.readOptionalLong(); + configIntervalInMinutes = in.readOptionalLong(); + isHC = in.readOptionalBoolean(); + + if (in.readBoolean()) { + forecastsValues = in.readFloatArray(); + forecastsUppers = in.readFloatArray(); + forecastsLowers = 
in.readFloatArray(); + } else { + forecastsValues = null; + forecastsUppers = null; + forecastsLowers = null; + } + taskId = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(dataQuality); + out.writeVInt(features.size()); + for (FeatureData feature : features) { + feature.writeTo(out); + } + out.writeOptionalString(error); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalLong(configIntervalInMinutes); + out.writeOptionalBoolean(isHC); + + if (forecastsValues != null) { + if (forecastsUppers == null || forecastsLowers == null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "null value: forecastsUppers: %s, forecastsLowers: %s", forecastsUppers, forecastsLowers) + ); + } + out.writeBoolean(true); + out.writeFloatArray(forecastsValues); + out.writeFloatArray(forecastsUppers); + out.writeFloatArray(forecastsLowers); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(taskId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (dataQuality != null && !dataQuality.equals(Double.NaN)) { + builder.field(DATA_QUALITY_JSON_KEY, dataQuality); + } + if (error != null) { + builder.field(ERROR_JSON_KEY, error); + } + if (features != null && features.size() > 0) { + builder.startArray(FEATURES_JSON_KEY); + for (FeatureData feature : features) { + feature.toXContent(builder, params); + } + builder.endArray(); + } + if (rcfTotalUpdates != null) { + builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates); + } + if (forecastsValues != null) { + builder.field(FORECAST_VALUES_JSON_KEY, forecastsValues); + } + if (forecastsUppers != null) { + builder.field(FORECAST_UPPERS_JSON_KEY, forecastsUppers); + } + if (forecastsLowers != null) { + builder.field(FORECAST_LOWERS_JSON_KEY, forecastsLowers); + } + if (taskId != null) { + builder.field(CommonName.TASK_ID_FIELD, taskId); + } + // don't show interval as we only need to access it in memory to compute init estimated time remaining + builder.endObject(); + return builder; + } + + /** + * Convert ForecastResultResponse to ForecastResult + * + * @param forecastId Forecaster Id + * @param dataStartInstant data start time + * @param dataEndInstant data end time + * @param executionStartInstant execution start time + * @param executionEndInstant execution end time + * @param schemaVersion Schema version + * @param user Forecaster author + * @param error Error + * @return converted ForecastResult + */ + @Override + public List<ForecastResult> toIndexableResults( + String forecastId, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + Integer schemaVersion, + User user, + String error + ) { + // Forecast interval in milliseconds + long forecasterIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis(); + return ForecastResult + .fromRawRCFCasterResult( + forecastId, + forecasterIntervalMilli, + dataQuality, + features, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + Optional.empty(), + user, + schemaVersion, + null, // single-stream real-time has no model id + forecastsValues, + forecastsUppers, + forecastsLowers, + taskId // null for real-time results; only run-once (test) results carry a task id + ); + } + + @Override + public boolean shouldSave() { + return super.shouldSave() || (forecastsValues != null && forecastsValues.length > 0); + } +} diff --git
a/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java new file mode 100644 index 000000000..1db61e5d9 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java @@ -0,0 +1,188 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastResultTransportAction extends HandledTransportAction<ForecastResultRequest, ForecastResultResponse> { + + private static final Logger LOG = LogManager.getLogger(ForecastResultTransportAction.class); + private ResultProcessor resultProcessor; + private final Client client; + private CircuitBreakerService circuitBreakerService; + // Cache of HC forecaster ids, used to count HC failure stats. We can tell whether a forecaster + // is HC by checking if its id exists in this set: the id is added when a real-time analysis + // starts and removed once that analysis completes.
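+ // Editor's note (sketch of the assumed flow, inferred from this hunk): the Optional.of(hcForecasters) handed to + // resultProcessor.onGetConfig below presumably lets the processor register the id once it loads the config and + // finds the forecaster is HC; the wrapped listener in doExecute then removes the id on success or failure, so + // membership in this set marks an HC request in flight.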
+ private final Set<String> hcForecasters; + private final ForecastStats forecastStats; + private final NodeStateManager nodeStateManager; + private final Settings settings; + private final ClusterService clusterService; + private final ThreadPool threadPool; + private final HashRing hashRing; + private final TransportService transportService; + private final ForecastTaskManager realTimeTaskManager; + private final NamedXContentRegistry xContentRegistry; + private final SecurityClientUtil clientUtil; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final FeatureManager featureManager; + + @Inject + public ForecastResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + Client client, + SecurityClientUtil clientUtil, + NodeStateManager nodeStateManager, + FeatureManager featureManager, + ForecastModelManager modelManager, + HashRing hashRing, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + CircuitBreakerService circuitBreakerService, + ForecastStats forecastStats, + ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager realTimeTaskManager + ) { + super(ForecastResultAction.NAME, transportService, actionFilters, ForecastResultRequest::new); + + this.settings = settings; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.transportService = transportService; + this.realTimeTaskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.clientUtil = clientUtil; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.featureManager = featureManager; + + this.client = client; + this.circuitBreakerService = circuitBreakerService; + this.hcForecasters = new HashSet<>(); + this.forecastStats = forecastStats; + this.nodeStateManager = nodeStateManager; + + this.resultProcessor = null; + } + + @Override + protected void doExecute(Task task, ForecastResultRequest request, ActionListener<ForecastResultResponse> listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String forecastID = request.getConfigId(); + ActionListener<ForecastResultResponse> original = listener; + listener = ActionListener.wrap(r -> { + hcForecasters.remove(forecastID); + original.onResponse(r); + }, e -> { + // Count the failure in stats unless it is a TimeSeriesException explicitly + // flagged as not counted in stats.
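+ // For example, the EndRunException thrown below when forecasting is disabled is built with countedInStats(false), + // so it increments neither FORECAST_EXECUTE_FAIL_COUNT nor FORECAST_HC_EXECUTE_FAIL_COUNT.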
+ if (!(e instanceof TimeSeriesException) || ((TimeSeriesException) e).isCountedInStats()) { + forecastStats.getStat(StatNames.FORECAST_EXECUTE_FAIL_COUNT.getName()).increment(); + if (hcForecasters.contains(forecastID)) { + forecastStats.getStat(StatNames.FORECAST_HC_EXECUTE_FAIL_COUNT.getName()).increment(); + } + } + hcForecasters.remove(forecastID); + original.onFailure(e); + }); + + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new EndRunException(forecastID, ForecastCommonMessages.DISABLED_ERR_MSG, true).countedInStats(false); + } + + forecastStats.getStat(StatNames.FORECAST_EXECUTE_REQUEST_COUNT.getName()).increment(); + + if (circuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(forecastID, CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + this.resultProcessor = new ForecastResultProcessor( + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityForecastResultAction.NAME, + StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + forecastStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + ForecastResultResponse.class, + featureManager, + AnalysisType.FORECAST, + false + ); + + try { + nodeStateManager + .getConfig( + forecastID, + AnalysisType.FORECAST, + resultProcessor.onGetConfig(listener, forecastID, request, Optional.of(hcForecasters)) + ); + } catch (Exception ex) { + ResultProcessor.handleExecuteException(ex, listener, forecastID); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceAction.java new file mode 100644 index 000000000..addbf3216 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastRunOnceAction extends ActionType<ForecastResultResponse> { + // External Action which is used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/runOnce"; + public static final ForecastRunOnceAction INSTANCE = new ForecastRunOnceAction(); + + private ForecastRunOnceAction() { + super(NAME, ForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileAction.java new file mode 100644 index 000000000..1025d3dfc --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.BooleanResponse; + +public class ForecastRunOnceProfileAction extends ActionType<BooleanResponse> { + // External Action which is used for public facing RestAPIs.
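+ // Usage sketch (as wired later in this patch): ForecastRunOnceTransportAction calls + // client.execute(ForecastRunOnceProfileAction.INSTANCE, new ForecastRunOnceProfileRequest(forecasterId), listener); + // a true answer means the cold-start or checkpoint-read queues still hold requests for the forecaster, i.e. a test is in flight.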
+ public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/runOnceProfile"; + public static final ForecastRunOnceProfileAction INSTANCE = new ForecastRunOnceProfileAction(); + + private ForecastRunOnceProfileAction() { + super(NAME, BooleanResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileRequest.java new file mode 100644 index 000000000..67f807f50 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileRequest.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ForecastRunOnceProfileRequest extends BaseNodesRequest { + private String configId; + + public ForecastRunOnceProfileRequest(StreamInput in) throws IOException { + super(in); + configId = in.readString(); + } + + /** + * Constructor + * + * @param configId config id + */ + public ForecastRunOnceProfileRequest(String configId, DiscoveryNode... nodes) { + super(nodes); + this.configId = configId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configId); + } + + public String getConfigId() { + return configId; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java new file mode 100644 index 000000000..a9fe218a8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceProfileTransportAction.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.BooleanNodeResponse; +import org.opensearch.timeseries.transport.BooleanResponse; +import org.opensearch.timeseries.transport.ForecastRunOnceProfileNodeRequest; +import org.opensearch.transport.TransportService; + +public class ForecastRunOnceProfileTransportAction extends + TransportNodesAction { + private final ForecastColdStartWorker coldStartWorker; + private final ForecastCheckpointReadWorker checkpointReadWorker; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param settings Node settings accessor + */ + @Inject + public ForecastRunOnceProfileTransportAction( + ThreadPool 
threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Settings settings, + ForecastColdStartWorker coldStartWorker, + ForecastCheckpointReadWorker checkpointReadWorker + ) { + super( + ForecastRunOnceProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ForecastRunOnceProfileRequest::new, + ForecastRunOnceProfileNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + BooleanNodeResponse.class + ); + this.coldStartWorker = coldStartWorker; + this.checkpointReadWorker = checkpointReadWorker; + } + + @Override + protected BooleanResponse newResponse( + ForecastRunOnceProfileRequest request, + List responses, + List failures + ) { + return new BooleanResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ForecastRunOnceProfileNodeRequest newNodeRequest(ForecastRunOnceProfileRequest request) { + return new ForecastRunOnceProfileNodeRequest(request); + } + + @Override + protected BooleanNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new BooleanNodeResponse(in); + } + + @Override + protected BooleanNodeResponse nodeOperation(ForecastRunOnceProfileNodeRequest request) { + String configId = request.getConfigId(); + + return new BooleanNodeResponse( + clusterService.localNode(), + coldStartWorker.hasConfigId(configId) || checkpointReadWorker.hasConfigId(configId) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java new file mode 100644 index 000000000..dcd132a05 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastRunOnceTransportAction.java @@ -0,0 +1,354 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.core.rest.RestStatus.CONFLICT; +import static org.opensearch.core.rest.RestStatus.FORBIDDEN; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_FORECAST_FEATURES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_HC_FORECASTERS; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import 
org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public class ForecastRunOnceTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(ForecastRunOnceTransportAction.class); + private ResultProcessor resultProcessor; + private final Client client; + private CircuitBreakerService circuitBreakerService; + private final NodeStateManager nodeStateManager; + + private final Settings settings; + private final ClusterService clusterService; + private final ThreadPool threadPool; + private final HashRing hashRing; + private final TransportService transportService; + private final ForecastTaskManager taskManager; + private final NamedXContentRegistry xContentRegistry; + private final SecurityClientUtil clientUtil; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final FeatureManager featureManager; + private final ForecastStats forecastStats; + private volatile Boolean filterByEnabled; + + protected volatile Integer maxSingleStreamForecasters; + protected volatile Integer maxHCForecasters; + protected volatile Integer maxForecastFeatures; + protected volatile Integer maxCategoricalFields; + + @Inject + public ForecastRunOnceTransportAction( + ActionFilters actionFilters, + TransportService 
transportService, + Settings settings, + Client client, + SecurityClientUtil clientUtil, + NodeStateManager nodeStateManager, + FeatureManager featureManager, + ForecastModelManager modelManager, + HashRing hashRing, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + CircuitBreakerService circuitBreakerService, + ForecastStats forecastStats, + ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager realTimeTaskManager + ) { + super(ForecastRunOnceAction.NAME, transportService, actionFilters, ForecastResultRequest::new); + + this.resultProcessor = null; + this.settings = settings; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.transportService = transportService; + this.taskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.clientUtil = clientUtil; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.featureManager = featureManager; + this.forecastStats = forecastStats; + + this.client = client; + this.circuitBreakerService = circuitBreakerService; + this.nodeStateManager = nodeStateManager; + filterByEnabled = ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + + this.maxSingleStreamForecasters = MAX_SINGLE_STREAM_FORECASTERS.get(settings); + this.maxHCForecasters = MAX_HC_FORECASTERS.get(settings); + this.maxForecastFeatures = MAX_FORECAST_FEATURES; + this.maxCategoricalFields = ForecastNumericSetting.maxCategoricalFields(); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_SINGLE_STREAM_FORECASTERS, it -> maxSingleStreamForecasters = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_HC_FORECASTERS, it -> maxHCForecasters = it); + } + + @Override + protected void doExecute(Task task, ForecastResultRequest request, ActionListener<ForecastResultResponse> listener) { + String forecastID = request.getConfigId(); + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + + resolveUserAndExecute( + user, + forecastID, + filterByEnabled, + listener, + (forecaster) -> executeRunOnce(forecastID, request, listener), + client, + clusterService, + xContentRegistry, + Forecaster.class + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(new OpenSearchStatusException("Failed to run once forecaster " + forecastID, INTERNAL_SERVER_ERROR)); + } + } + + private void executeRunOnce(String forecastID, ForecastResultRequest request, ActionListener<ForecastResultResponse> listener) { + if (!ForecastEnabledSetting.isForecastEnabled()) { + listener.onFailure(new OpenSearchStatusException(ForecastCommonMessages.DISABLED_ERR_MSG, FORBIDDEN)); + // bail out; without this return we would fall through and keep executing after responding + return; + } + + if (circuitBreakerService.isOpen()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, SERVICE_UNAVAILABLE)); + return; + } + + client.execute(ForecastRunOnceProfileAction.INSTANCE, new ForecastRunOnceProfileRequest(forecastID), ActionListener.wrap(r -> { + if (r.isAnswerTrue()) { + listener + .onFailure( + new OpenSearchStatusException( + "Cannot start a new test " + forecastID + " since the current test hasn't finished.", + CONFLICT + ) + ); + } else { + nodeStateManager.getJob(forecastID, ActionListener.wrap(jobOptional -> { + if (jobOptional.isPresent() && jobOptional.get().isEnabled()) { + listener
+ .onFailure( + new OpenSearchStatusException("Cannot run once " + forecastID + " when real time job is running.", CONFLICT) + ); + return; + } + + triggerRunOnce(forecastID, request, listener); + }, e -> { + if (e instanceof IndexNotFoundException) { + triggerRunOnce(forecastID, request, listener); + } else { + LOG.error(e); + listener + .onFailure(new OpenSearchStatusException("Failed to verify whether job " + forecastID + " has started.", CONFLICT)); + } + })); + } + }, e -> { + LOG.error(e); + listener.onFailure(new OpenSearchStatusException("Failed to run once forecaster " + forecastID, INTERNAL_SERVER_ERROR)); + })); + } + + private void checkIfRunOnceFinished(String forecastID, String taskId, AtomicInteger waitTimes) { + client.execute(ForecastRunOnceProfileAction.INSTANCE, new ForecastRunOnceProfileRequest(forecastID), ActionListener.wrap(r -> { + if (r.isAnswerTrue()) { + handleRunOnceNotFinished(forecastID, taskId, waitTimes); + } else { + handleRunOnceFinished(forecastID, taskId); + } + }, e -> { + LOG.error("Failed to profile run once of forecaster " + forecastID, e); + handleRunOnceNotFinished(forecastID, taskId, waitTimes); + })); + } + + private void handleRunOnceNotFinished(String forecastID, String taskId, AtomicInteger waitTimes) { + // poll every 10 seconds, up to 10 times (~100 seconds total), before declaring the test timed out + if (waitTimes.get() < 10) { + waitTimes.incrementAndGet(); + threadPool + .schedule( + () -> checkIfRunOnceFinished(forecastID, taskId, waitTimes), + new TimeValue(10, TimeUnit.SECONDS), + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME + ); + } else { + LOG.warn("Run once of forecaster {} timed out", forecastID); + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + } + } + + private void handleRunOnceFinished(String forecastID, String taskId) { + LOG.info("Run once of forecaster {} finished", forecastID); + nodeStateManager.getConfig(forecastID, AnalysisType.FORECAST, ActionListener.wrap(configOptional -> { + if (configOptional.isEmpty()) { + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + return; + } + checkForecastResults(forecastID, taskId, configOptional.get()); + }, e -> { + LOG.error("Failed to get config", e); + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + })); + } + + private void checkForecastResults(String forecastID, String taskId, Config config) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, forecastID)); + ExistsQueryBuilder forecastsExistFilter = QueryBuilders.existsQuery(ForecastResult.VALUE_FIELD); + filterQuery.must(forecastsExistFilter); + // run-once analysis results are also stored in the result index and carry a non-null task_id.
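+ // Sketch of the assembled query (field names via the constants above): + // { "bool": { "filter": [ { "term": { FORECASTER_ID_KEY: forecastID } }, { "term": { TASK_ID_FIELD: taskId } } ], + // "must": [ { "exists": { "field": VALUE_FIELD } } ] } } -- a single hit is enough to mark the test complete.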
+ filterQuery.filter(QueryBuilders.termQuery(CommonName.TASK_ID_FIELD, taskId)); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN); + request.source(source); + if (config.getCustomResultIndex() != null) { + request.indices(config.getCustomResultIndex()); + } + + client.search(request, ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + if (hits.getTotalHits().value > 0) { + // has at least one result + updateTaskState(forecastID, taskId, TaskState.TEST_COMPLETE); + } else { + updateTaskState(forecastID, taskId, TaskState.INIT_TEST_FAILED); + } + }, e -> { + LOG.error("Fail to search result", e); + updateTaskState(forecastID, taskId, TaskState.INACTIVE); + })); + } + + private void updateTaskState(String forecastID, String taskId, TaskState state) { + taskManager.updateTask(taskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, state.name()), ActionListener.wrap(updateResponse -> { + LOG.info("Updated forecaster task: {} state as: {} for forecaster: {}", taskId, state.name(), forecastID); + }, e -> { LOG.error("Failed to update forecaster task: {} for forecaster: {}", taskId, forecastID, e); })); + } + + private void triggerRunOnce(String forecastID, ForecastResultRequest request, ActionListener listener) { + try { + resultProcessor = new ForecastResultProcessor( + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityForecastResultAction.NAME, + StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + forecastStats, + taskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + ForecastResultResponse.class, + featureManager, + AnalysisType.FORECAST, + true + ); + + ActionListener wrappedListener = ActionListener.wrap(r -> { + AtomicInteger waitTimes = new AtomicInteger(0); + + threadPool + .schedule( + () -> checkIfRunOnceFinished(forecastID, r.getTaskId(), waitTimes), + new TimeValue(10, TimeUnit.SECONDS), + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME + ); + listener.onResponse(r); + }, e -> { + LOG.error("Failed to finish run once of forecaster " + forecastID, e); + listener.onFailure(new OpenSearchStatusException("Failed to run once forecaster " + forecastID, INTERNAL_SERVER_ERROR)); + }); + + nodeStateManager + .getConfig( + forecastID, + AnalysisType.FORECAST, + resultProcessor.onGetConfig(wrappedListener, forecastID, request, Optional.empty()) + ); + + // check for status + } catch (Exception ex) { + ResultProcessor.handleExecuteException(ex, listener, forecastID); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java new file mode 100644 index 000000000..6b8687b82 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastSingleStreamResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "singlestream/result"; + public static final ForecastSingleStreamResultAction INSTANCE = new ForecastSingleStreamResultAction(); + + private ForecastSingleStreamResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java new file mode 100644 index 000000000..e051e878d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java @@ -0,0 +1,239 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.SingleStreamResultRequest; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastSingleStreamResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = 
LogManager.getLogger(ForecastSingleStreamResultTransportAction.class); + private CircuitBreakerService circuitBreakerService; + private ForecastCacheProvider cache; + private final NodeStateManager stateManager; + private ForecastCheckpointReadWorker checkpointReadQueue; + private ForecastModelManager modelManager; + private ForecastIndexManagement indexUtil; + private ForecastResultWriteWorker resultWriteQueue; + private ForecastStats stats; + private ForecastColdStartWorker forecastColdStartQueue; + + @Inject + public ForecastSingleStreamResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + CircuitBreakerService circuitBreakerService, + ForecastCacheProvider cache, + NodeStateManager stateManager, + ForecastCheckpointReadWorker checkpointReadQueue, + ForecastModelManager modelManager, + ForecastIndexManagement indexUtil, + ForecastResultWriteWorker resultWriteQueue, + ForecastStats stats, + ForecastColdStartWorker forecastColdStartQueue + ) { + super(ForecastSingleStreamResultAction.NAME, transportService, actionFilters, SingleStreamResultRequest::new); + this.circuitBreakerService = circuitBreakerService; + this.cache = cache; + this.stateManager = stateManager; + this.checkpointReadQueue = checkpointReadQueue; + this.modelManager = modelManager; + this.indexUtil = indexUtil; + this.resultWriteQueue = resultWriteQueue; + this.stats = stats; + this.forecastColdStartQueue = forecastColdStartQueue; + } + + @Override + protected void doExecute(Task task, SingleStreamResultRequest request, ActionListener<AcknowledgedResponse> listener) { + if (circuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String forecasterId = request.getConfigId(); + + Optional<Exception> previousException = stateManager.fetchExceptionAndClear(forecasterId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", forecasterId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, forecasterId); + } + + stateManager.getConfig(forecasterId, AnalysisType.FORECAST, onGetConfig(listener, forecasterId, request, previousException)); + } catch (Exception exception) { + LOG.error("failed to get single-stream forecasts", exception); + listener.onFailure(exception); + } + } + + public ActionListener<Optional<? extends Config>> onGetConfig( + ActionListener<AcknowledgedResponse> listener, + String forecasterId, + SingleStreamResultRequest request, + Optional<Exception> prevException + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + return; + } + + Config config = configOptional.get(); + + Instant executionStartTime = Instant.now(); + + String modelId = request.getModelId(); + double[] datapoint = request.getDataPoint(); + ModelState<RCFCaster> modelState = cache.get().get(modelId, config); + if (modelState == null) { + // cache miss: queue a checkpoint read to restore or cold-start the model + checkpointReadQueue + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + request.getModelId(), + datapoint, + request.getStart(), + request.getTaskId() + ) + ); + } else { + try { + RCFCasterResult result =
modelManager + .getResult( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + modelState, + modelId, + Optional.empty(), + config, + request.getTaskId() + ); + // result.getRcfScore() == 0 means the model is not initialized + if (result.getRcfScore() > 0) { + List<ForecastResult> indexableResults = result + .toIndexableResults( + config, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + executionStartTime, + Instant.now(), + ParseUtils.getFeatureData(datapoint, config), + Optional.empty(), + indexUtil.getSchemaVersion(ForecastIndex.RESULT), + modelId, + null, + null + ); + + for (ForecastResult r : indexableResults) { + resultWriteQueue + .put( + ForecastResultWriteRequest + .create( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + r, + config.getCustomResultIndex(), + ForecastResultWriteRequest.class + ) + ); + } + } + } catch (IllegalArgumentException e) { + // failing to score is likely due to model corruption; remove the model and re-queue a cold start to recover. + LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName()).increment(); + cache.get().removeModel(forecasterId, modelId); + forecastColdStartQueue + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + modelId, + datapoint, + request.getStart(), + request.getTaskId() + ) + ); + } + } + + // respond, propagating any previously recorded exception + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "failed to get single-stream forecasts for forecaster [{}]: start: [{}], end: [{}]", + forecasterId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesAction.java new file mode 100644 index 000000000..3d1bd793e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesAction.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.StatsNodesResponse; + +/** + * ForecastStatsNodesAction class + */ +public class ForecastStatsNodesAction extends ActionType<StatsNodesResponse> { + + // Internal Action which is not used for public facing RestAPIs.
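+ // Note the naming convention used throughout this patch: INTERNAL_ACTION_PREFIX for node-to-node transport actions + // like this one, and EXTERNAL_ACTION_PREFIX for actions that back public REST endpoints (e.g. ForecasterJobAction below).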
+ public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; + public static final ForecastStatsNodesAction INSTANCE = new ForecastStatsNodesAction(); + + /** + * Constructor + */ + private ForecastStatsNodesAction() { + super(NAME, StatsNodesResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesTransportAction.java new file mode 100644 index 000000000..9d2ec58d7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastStatsNodesTransportAction.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.transport.BaseStatsNodesTransportAction; +import org.opensearch.transport.TransportService; + +/** + * ForecastStatsNodesTransportAction contains the logic to extract the stats from the nodes + */ +public class ForecastStatsNodesTransportAction extends BaseStatsNodesTransportAction { + @Inject + public ForecastStatsNodesTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ForecastStats stats + ) { + super(threadPool, clusterService, transportService, actionFilters, stats, ForecastStatsNodesAction.NAME); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java new file mode 100644 index 000000000..bfd915288 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.JobResponse; + +public class ForecasterJobAction extends ActionType<JobResponse> { + // External Action which is used for public facing RestAPIs.
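+ // Editor's sketch (assumed wiring, not shown in this hunk): a start/stop REST handler would execute this action, e.g. + // client.execute(ForecasterJobAction.INSTANCE, jobRequest, listener); the BaseJobTransportAction subclass below checks + // backend-role permissions and delegates to ForecastIndexJobActionHandler to start or stop the forecaster's job.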
+ public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/jobmanagement"; + public static final ForecasterJobAction INSTANCE = new ForecasterJobAction(); + + private ForecasterJobAction() { + super(NAME, JobResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java new file mode 100644 index 000000000..b6f35c27f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_START_FORECASTER; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_STOP_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseJobTransportAction; +import org.opensearch.transport.TransportService; + +public class ForecasterJobTransportAction extends + BaseJobTransportAction { + + @Inject + public ForecasterJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + ForecastIndexJobActionHandler forecastIndexJobActionHandler + ) { + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + FORECAST_FILTER_BY_BACKEND_ROLES, + ForecasterJobAction.NAME, + FORECAST_REQUEST_TIMEOUT, + FAIL_TO_START_FORECASTER, + FAIL_TO_STOP_FORECASTER, + Forecaster.class, + forecastIndexJobActionHandler + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java new file mode 100644 index 000000000..ef5d13540 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class GetForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/get"; + public static final GetForecasterAction INSTANCE = new GetForecasterAction(); + + private GetForecasterAction() { + super(NAME, GetForecasterResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java new file mode 100644 index 000000000..5ac88a509 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java @@ -0,0 +1,220 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.ForecasterProfile; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class GetForecasterResponse extends ActionResponse implements ToXContentObject { + + public static final String FORECASTER_PROFILE = "forecasterProfile"; + public static final String ENTITY_PROFILE = "entityProfile"; + private String id; + private long version; + private long primaryTerm; + private long seqNo; + private Forecaster forecaster; + private Job forecastJob; + private ForecastTask realtimeTask; + private ForecastTask runOnceTask; + private RestStatus restStatus; + private ForecasterProfile forecasterProfile; + private EntityProfile entityProfile; + private boolean profileResponse; + private boolean returnJob; + private boolean returnTask; + + public GetForecasterResponse(StreamInput in) throws IOException { + super(in); + profileResponse = in.readBoolean(); + if (profileResponse) { + String profileType = in.readString(); + if (FORECASTER_PROFILE.equals(profileType)) { + forecasterProfile = new ForecasterProfile(in); + } else { + entityProfile = new EntityProfile(in); + } + } else { + id = in.readString(); + version = in.readLong(); + primaryTerm = in.readLong(); + seqNo = in.readLong(); + restStatus = in.readEnum(RestStatus.class); + forecaster = new Forecaster(in); + returnJob = in.readBoolean(); + if (returnJob) { + forecastJob = new Job(in); + } else { + forecastJob = null; + } + returnTask = in.readBoolean(); + if (in.readBoolean()) { + realtimeTask = new ForecastTask(in); + } else { + realtimeTask = null; + } + if (in.readBoolean()) { + runOnceTask = new ForecastTask(in); + } else { + runOnceTask = null; + } + } + + } + + public GetForecasterResponse( + String id, + long version, + long primaryTerm, + long 
seqNo, + Forecaster forecaster, + Job job, + boolean returnJob, + ForecastTask realtimeTask, + ForecastTask runOnceTask, + boolean returnTask, + RestStatus restStatus, + ForecasterProfile forecasterProfile, + EntityProfile entityProfile, + boolean profileResponse + ) { + this.id = id; + this.version = version; + this.primaryTerm = primaryTerm; + this.seqNo = seqNo; + this.forecaster = forecaster; + this.returnJob = returnJob; + if (this.returnJob) { + this.forecastJob = job; + } else { + this.forecastJob = null; + } + this.returnTask = returnTask; + if (this.returnTask) { + this.realtimeTask = realtimeTask; + this.runOnceTask = runOnceTask; + } else { + this.realtimeTask = null; + this.runOnceTask = null; + } + this.restStatus = restStatus; + this.forecasterProfile = forecasterProfile; + this.entityProfile = entityProfile; + this.profileResponse = profileResponse; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (profileResponse) { + out.writeBoolean(true); // profileResponse is true + // assumed invariant: when profileResponse is true, exactly one of forecasterProfile/entityProfile is non-null; + // otherwise the read side, which expects a profile type string next, would fail + if (forecasterProfile != null) { + out.writeString(FORECASTER_PROFILE); + forecasterProfile.writeTo(out); + } else if (entityProfile != null) { + out.writeString(ENTITY_PROFILE); + entityProfile.writeTo(out); + } + } else { + out.writeBoolean(false); // profileResponse is false + out.writeString(id); + out.writeLong(version); + out.writeLong(primaryTerm); + out.writeLong(seqNo); + out.writeEnum(restStatus); + forecaster.writeTo(out); + if (returnJob) { + out.writeBoolean(true); // returnJob is true + forecastJob.writeTo(out); + } else { + out.writeBoolean(false); // returnJob is false + } + out.writeBoolean(returnTask); + if (realtimeTask != null) { + out.writeBoolean(true); + realtimeTask.writeTo(out); + } else { + out.writeBoolean(false); + } + if (runOnceTask != null) { + out.writeBoolean(true); + runOnceTask.writeTo(out); + } else { + out.writeBoolean(false); + } + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (profileResponse) { + if (forecasterProfile != null) { + forecasterProfile.toXContent(builder, params); + } else { + entityProfile.toXContent(builder, params); + } + } else { + builder.startObject(); + builder.field(RestHandlerUtils._ID, id); + builder.field(RestHandlerUtils._VERSION, version); + builder.field(RestHandlerUtils._PRIMARY_TERM, primaryTerm); + builder.field(RestHandlerUtils._SEQ_NO, seqNo); + builder.field(RestHandlerUtils.REST_STATUS, restStatus); + builder.field(RestHandlerUtils.FORECASTER, forecaster); + if (returnJob) { + builder.field(RestHandlerUtils.FORECASTER_JOB, forecastJob); + } + if (returnTask) { + builder.field(RestHandlerUtils.REALTIME_TASK, realtimeTask); + builder.field(RestHandlerUtils.RUN_ONCE_TASK, runOnceTask); + } + builder.endObject(); + } + return builder; + } + + public Job getForecastJob() { + return forecastJob; + } + + public ForecastTask getRealtimeTask() { + return realtimeTask; + } + + public ForecastTask getRunOnceTask() { + return runOnceTask; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public ForecasterProfile getForecasterProfile() { + return forecasterProfile; + } + + public EntityProfile getEntityProfile() { + return entityProfile; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java new file mode 100644 index 000000000..08bfa79d0 --- /dev/null +++
b/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java @@ -0,0 +1,150 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.util.Optional; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ForecastEntityProfileRunner; +import org.opensearch.forecast.ForecastProfileRunner; +import org.opensearch.forecast.ForecastTaskProfileRunner; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskProfile; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.ForecasterProfile; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class GetForecasterTransportAction extends + BaseGetConfigTransportAction { + + @Inject + public GetForecasterTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager forecastTaskManager, + ForecastTaskProfileRunner taskProfileRunner + ) { + super( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + settings, + xContentRegistry, + forecastTaskManager, + GetForecasterAction.NAME, + Forecaster.class, + Forecaster.FORECAST_PARSE_FIELD_NAME, + ForecastTaskType.ALL_FORECAST_TASK_TYPES, + ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER.name(), + ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM.name(), + ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER.name(), + ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM.name(), + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + taskProfileRunner + ); + } + + @Override + protected GetForecasterResponse createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + Forecaster config, + Job job, + boolean returnJob, + Optional realtimeTask, + Optional historicalTask, + boolean returnTask, + RestStatus restStatus, + ForecasterProfile forecasterProfile, + EntityProfile entityProfile, + boolean profileResponse + ) { + return new GetForecasterResponse( + id, + version, + primaryTerm, + seqNo, + config, + job, + 
returnJob, + realtimeTask.orElse(null), + historicalTask.orElse(null), + returnTask, + restStatus, + forecasterProfile, + entityProfile, + profileResponse + ); + } + + @Override + protected ForecastEntityProfileRunner createEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ) { + return new ForecastEntityProfileRunner(client, clientUtil, xContentRegistry, TimeSeriesSettings.NUM_MIN_SAMPLES); + } + + @Override + protected ForecastProfileRunner createProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ForecastTaskManager taskManager, + ForecastTaskProfileRunner taskProfileRunner + ) { + return new ForecastProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager, + taskProfileRunner + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java new file mode 100644 index 000000000..23613a89f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class IndexForecasterAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/write"; + public static final IndexForecasterAction INSTANCE = new IndexForecasterAction(); + + private IndexForecasterAction() { + super(NAME, IndexForecasterResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java new file mode 100644 index 000000000..60a3a1964 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + + package org.opensearch.forecast.transport; + + import java.io.IOException; + + import org.opensearch.action.ActionRequest; + import org.opensearch.action.ActionRequestValidationException; + import org.opensearch.action.support.WriteRequest; + import org.opensearch.common.unit.TimeValue; + import org.opensearch.core.common.io.stream.StreamInput; + import org.opensearch.core.common.io.stream.StreamOutput; + import org.opensearch.forecast.model.Forecaster; + import org.opensearch.rest.RestRequest; + + public class IndexForecasterRequest extends ActionRequest { + private String forecastID; + private long seqNo; + private long primaryTerm; + private WriteRequest.RefreshPolicy refreshPolicy; + private Forecaster forecaster; + private RestRequest.Method method; + private TimeValue requestTimeout; + private Integer maxSingleStreamForecasters; + private Integer maxHCForecasters; + private Integer maxForecastFeatures; + private Integer maxCategoricalFields; + + public IndexForecasterRequest(StreamInput in) throws IOException { + super(in); + forecastID = in.readString(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + refreshPolicy = in.readEnum(WriteRequest.RefreshPolicy.class); + forecaster = new Forecaster(in); + method = in.readEnum(RestRequest.Method.class); + requestTimeout = in.readTimeValue(); + maxSingleStreamForecasters = in.readInt(); + maxHCForecasters = in.readInt(); + maxForecastFeatures = in.readInt(); + maxCategoricalFields = in.readInt(); + } + + public IndexForecasterRequest( + String forecasterID, + long seqNo, + long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Forecaster forecaster, + RestRequest.Method method, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxForecastFeatures, + Integer maxCategoricalFields + ) { + super(); + this.forecastID = forecasterID; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.refreshPolicy = refreshPolicy; + this.forecaster = forecaster; + this.method = method; + this.requestTimeout = requestTimeout; + this.maxSingleStreamForecasters = maxSingleStreamForecasters; + this.maxHCForecasters = maxHCForecasters; + this.maxForecastFeatures = maxForecastFeatures; + this.maxCategoricalFields = maxCategoricalFields; + } + + public String getForecasterID() { + return forecastID; + } + + public long getSeqNo() { + return seqNo; + } + + public long getPrimaryTerm() { + return primaryTerm; + } + + public WriteRequest.RefreshPolicy getRefreshPolicy() { + return refreshPolicy; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public RestRequest.Method getMethod() { + return method; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxSingleStreamForecasters() { + return maxSingleStreamForecasters; + } + + public Integer getMaxHCForecasters() { + return maxHCForecasters; + } + + public Integer getMaxForecastFeatures() { + return maxForecastFeatures; + } + + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(forecastID); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + out.writeEnum(refreshPolicy); + forecaster.writeTo(out); + out.writeEnum(method); + out.writeTimeValue(requestTimeout); + out.writeInt(maxSingleStreamForecasters); + out.writeInt(maxHCForecasters); + out.writeInt(maxForecastFeatures);
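+ // NOTE: the write order above and below must mirror IndexForecasterRequest(StreamInput in)
+ // field-for-field, or the transport stream is corrupted. A hypothetical round-trip check
+ // (BytesStreamOutput here is a test-time assumption, not part of this PR):
+ //   BytesStreamOutput bytes = new BytesStreamOutput();
+ //   request.writeTo(bytes);
+ //   IndexForecasterRequest copy = new IndexForecasterRequest(bytes.bytes().streamInput());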
out.writeInt(maxCategoricalFields); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java similarity index 76% rename from src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java rename to src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java index f65c5a06b..85362a07d 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.forecast.transport; import java.io.IOException; @@ -19,29 +19,33 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.Forecaster; import org.opensearch.timeseries.util.RestHandlerUtils; -public class AnomalyDetectorJobResponse extends ActionResponse implements ToXContentObject { +public class IndexForecasterResponse extends ActionResponse implements ToXContentObject { private final String id; private final long version; private final long seqNo; private final long primaryTerm; + private final Forecaster forecaster; private final RestStatus restStatus; - public AnomalyDetectorJobResponse(StreamInput in) throws IOException { + public IndexForecasterResponse(StreamInput in) throws IOException { super(in); id = in.readString(); version = in.readLong(); seqNo = in.readLong(); primaryTerm = in.readLong(); + forecaster = new Forecaster(in); restStatus = in.readEnum(RestStatus.class); } - public AnomalyDetectorJobResponse(String id, long version, long seqNo, long primaryTerm, RestStatus restStatus) { + public IndexForecasterResponse(String id, long version, long seqNo, long primaryTerm, Forecaster forecaster, RestStatus restStatus) { this.id = id; this.version = version; this.seqNo = seqNo; this.primaryTerm = primaryTerm; + this.forecaster = forecaster; this.restStatus = restStatus; } @@ -55,6 +59,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(version); out.writeLong(seqNo); out.writeLong(primaryTerm); + forecaster.writeTo(out); out.writeEnum(restStatus); } @@ -65,6 +70,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(RestHandlerUtils._ID, id) .field(RestHandlerUtils._VERSION, version) .field(RestHandlerUtils._SEQ_NO, seqNo) + .field(RestHandlerUtils.FORECASTER, forecaster) .field(RestHandlerUtils._PRIMARY_TERM, primaryTerm) .endObject(); } diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java new file mode 100644 index 000000000..c9bc28b72 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java @@ -0,0 +1,223 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_CREATE_FORECASTER; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_UPDATE_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; +import static org.opensearch.timeseries.util.ParseUtils.getConfig; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.List; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.rest.handler.IndexForecasterActionHandler; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class IndexForecasterTransportAction extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(IndexForecasterTransportAction.class); + private final Client client; + private final SecurityClientUtil clientUtil; + private final TransportService transportService; + private final ForecastIndexManagement forecastIndices; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final SearchFeatureDao searchFeatureDao; + private final ForecastTaskManager taskManager; + private final Settings settings; + + @Inject + public IndexForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ForecastIndexManagement forecastIndices, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + ForecastTaskManager taskManager + ) { + super(IndexForecasterAction.NAME, transportService, actionFilters, IndexForecasterRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.transportService = transportService; + this.clusterService = clusterService; + this.forecastIndices = forecastIndices; + this.xContentRegistry = 
xContentRegistry; + filterByEnabled = ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.searchFeatureDao = searchFeatureDao; + this.taskManager = taskManager; + this.settings = settings; + } + + @Override + protected void doExecute(Task task, IndexForecasterRequest request, ActionListener<IndexForecasterResponse> actionListener) { + User user = getUserContext(client); + String forecasterId = request.getForecasterID(); + RestRequest.Method method = request.getMethod(); + String errorMessage = method == RestRequest.Method.PUT ? FAIL_TO_UPDATE_FORECASTER : FAIL_TO_CREATE_FORECASTER; + ActionListener<IndexForecasterResponse> listener = wrapRestActionListener(actionListener, errorMessage); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + forecasterId, + method, + listener, + (forecaster) -> forecastExecute(request, user, forecaster, context, listener) + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void resolveUserAndExecute( + User requestedUser, + String forecasterId, + RestRequest.Method method, + ActionListener<IndexForecasterResponse> listener, + Consumer<Forecaster> function + ) { + try { + // requestedUser == null means security is disabled or the user is a superadmin. In that case we don't need to + // check whether the requesting user has access to the forecaster, but we still fetch the current forecaster on + // update so that we can keep its existing user data. + boolean filterByBackendRole = requestedUser == null ? false : filterByEnabled; + + // Check if the user has backend roles. + // When filtering by backend roles is enabled, block users who have none from creating/updating forecasters. + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new IllegalArgumentException(error)); + return; + } + } + if (method == RestRequest.Method.PUT) { + // Update forecaster request: check that the user has permission to update the forecaster. + // Get the forecaster and verify backend roles. + getConfig( + requestedUser, + forecasterId, + listener, + function, + client, + clusterService, + xContentRegistry, + filterByBackendRole, + Forecaster.class + ); + } else { + // Create forecaster request. No need to fetch the current forecaster.
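+ // Passing null signals that there is no existing forecaster to preserve; forecastExecute
+ // below then attaches the requesting user to the newly created forecaster.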
+ function.accept(null); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void forecastExecute( + IndexForecasterRequest request, + User user, + Forecaster currentForecaster, + ThreadContext.StoredContext storedContext, + ActionListener<IndexForecasterResponse> listener + ) { + forecastIndices.update(); + String forecasterId = request.getForecasterID(); + long seqNo = request.getSeqNo(); + long primaryTerm = request.getPrimaryTerm(); + WriteRequest.RefreshPolicy refreshPolicy = request.getRefreshPolicy(); + Forecaster forecaster = request.getForecaster(); + RestRequest.Method method = request.getMethod(); + TimeValue requestTimeout = request.getRequestTimeout(); + Integer maxSingleStreamForecasters = request.getMaxSingleStreamForecasters(); + Integer maxHCForecasters = request.getMaxHCForecasters(); + Integer maxForecastFeatures = request.getMaxForecastFeatures(); + Integer maxCategoricalFields = request.getMaxCategoricalFields(); + + storedContext.restore(); + checkIndicesAndExecute(forecaster.getIndices(), () -> { + // Don't replace the forecaster's user when updating the forecaster + // Github issue: https://github.com/opensearch-project/anomaly-detection/issues/124 + User forecastUser = currentForecaster == null ? user : currentForecaster.getUser(); + IndexForecasterActionHandler indexForecasterActionHandler = new IndexForecasterActionHandler( + clusterService, + client, + clientUtil, + transportService, + forecastIndices, + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry, + forecastUser, + taskManager, + searchFeatureDao, + settings + ); + indexForecasterActionHandler.start(listener); + }, listener); + } + + private void checkIndicesAndExecute(List<String> indices, ExecutorFunction function, ActionListener<IndexForecasterResponse> listener) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> { function.execute(); }, e -> { + // Due to the security plugin issue below, we get a security_exception when an invalid index name is supplied.
+ // https://github.com/opendistro-for-elasticsearch/security/issues/718 + LOG.error(e); + listener.onFailure(e); + })); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/RelationalOperation.java b/src/main/java/org/opensearch/forecast/transport/RelationalOperation.java new file mode 100644 index 000000000..853767ba0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/RelationalOperation.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +public enum RelationalOperation { + GREATER_THAN(">"), + GREATER_THAN_OR_EQUAL_TO(">="), + LESS_THAN("<"), + LESS_THAN_OR_EQUAL_TO("<="); + + private final String symbol; + + RelationalOperation(String symbol) { + this.symbol = symbol; + } + + public String getSymbol() { + return this.symbol; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksAction.java new file mode 100644 index 000000000..1dda427a0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class SearchForecastTasksAction extends ActionType<SearchResponse> { + // External action used for public-facing REST APIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; + public static final SearchForecastTasksAction INSTANCE = new SearchForecastTasksAction(); + + private SearchForecastTasksAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksTransportAction.java new file mode 100644 index 000000000..5545e7668 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecastTasksTransportAction.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchForecastTasksTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> { + private ForecastSearchHandler searchHandler; + + @Inject + public SearchForecastTasksTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ForecastSearchHandler searchHandler + ) { + super(SearchForecastTasksAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> listener) { + searchHandler.search(request, listener); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterAction.java new file mode 100644 index 000000000..b4777a4b7 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class SearchForecasterAction extends ActionType<SearchResponse> { + // External action used for public-facing REST APIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/search"; + public static final SearchForecasterAction INSTANCE = new SearchForecasterAction(); + + private SearchForecasterAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoAction.java new file mode 100644 index 000000000..ba5aec5cc --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.SearchConfigInfoResponse; + +public class SearchForecasterInfoAction extends ActionType<SearchConfigInfoResponse> { + // External action used for public-facing REST APIs.
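+ // Assuming ForecastCommonValue.EXTERNAL_ACTION_PREFIX is something like
+ // "cluster:admin/plugin/forecast/" (its value is not shown in this PR), NAME would
+ // resolve to "cluster:admin/plugin/forecast/forecaster/info".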
+ public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/info"; + public static final SearchForecasterInfoAction INSTANCE = new SearchForecasterInfoAction(); + + private SearchForecasterInfoAction() { + super(NAME, SearchConfigInfoResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoTransportAction.java new file mode 100644 index 000000000..7131c65a4 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterInfoTransportAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.timeseries.transport.BaseSearchConfigInfoTransportAction; +import org.opensearch.transport.TransportService; + +public class SearchForecasterInfoTransportAction extends BaseSearchConfigInfoTransportAction { + + @Inject + public SearchForecasterInfoTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(transportService, actionFilters, client, SearchForecasterInfoAction.NAME); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchForecasterTransportAction.java new file mode 100644 index 000000000..b53d09b76 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchForecasterTransportAction.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchForecasterTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> { + private ForecastSearchHandler searchHandler; + + @Inject + public SearchForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ForecastSearchHandler searchHandler + ) { + super(SearchForecasterAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> listener) { + searchHandler.search(request, listener); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultAction.java new file mode 100644 index 000000000..831c42b69 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class SearchTopForecastResultAction extends ActionType<SearchTopForecastResultResponse> { + // External action used for public-facing REST APIs.
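+ // Hypothetical caller-side sketch (types from this PR; the logger is an assumption):
+ //   client.execute(SearchTopForecastResultAction.INSTANCE, request, ActionListener.wrap(
+ //       r -> logger.info("got {} buckets", r.getForecastResultBuckets().size()),
+ //       e -> logger.error("top forecasts failed", e)));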
+ public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "result/topForecasts"; + public static final SearchTopForecastResultAction INSTANCE = new SearchTopForecastResultAction(); + + private SearchTopForecastResultAction() { + super(NAME, SearchTopForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultRequest.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultRequest.java new file mode 100644 index 000000000..d605d1ff8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultRequest.java @@ -0,0 +1,448 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.FilterBy; +import org.opensearch.forecast.model.Subaggregation; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * Request for getting the top forecast results for HC forecasters. + *
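* Example body (illustrative values; the field names match the constants parsed below):
* { "split_by": "service", "filter_by": "BUILD_IN_QUERY",
*   "build_in_query": "MIN_CONFIDENCE_INTERVAL_WIDTH", "forecast_from": 1692124800000 }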

+ * forecasterId, filterBy, and forecastFrom are required. + * One or two of buildInQuery, entity, threshold, filterQuery, subaggregations will be set to + * appropriate value depending on filterBy. + * Other parameters will be set to default values if left blank. + */ +public class SearchTopForecastResultRequest extends ActionRequest implements ToXContentObject { + + private static final String TASK_ID_FIELD = "task_id"; + private static final String SIZE_FIELD = "size"; + private static final String SPLIT_BY_FIELD = "split_by"; + private static final String FILTER_BY_FIELD = "filter_by"; + private static final String BUILD_IN_QUERY_FIELD = "build_in_query"; + private static final String THRESHOLD_FIELD = "threshold"; + private static final String RELATION_TO_THRESHOLD_FIELD = "relation_to_threshold"; + private static final String FILTER_QUERY_FIELD = "filter_query"; + public static final String SUBAGGREGATIONS_FIELD = "subaggregations"; + // forecast from looks for data end time + private static final String FORECAST_FROM_FIELD = "forecast_from"; + private static final String RUN_ONCE_FIELD = "run_once"; + + private String forecasterId; + private String taskId; + private boolean runOnce; + private Integer size; + private List splitBy; + private FilterBy filterBy; + private BuildInQuery buildInQuery; + private Float threshold; + private RelationalOperation relationToThreshold; + private QueryBuilder filterQuery; + private List subaggregations; + private Instant forecastFrom; + + public SearchTopForecastResultRequest(StreamInput in) throws IOException { + super(in); + forecasterId = in.readOptionalString(); + taskId = in.readOptionalString(); + runOnce = in.readBoolean(); + size = in.readOptionalInt(); + splitBy = in.readOptionalStringList(); + if (in.readBoolean()) { + filterBy = in.readEnum(FilterBy.class); + } else { + filterBy = null; + } + if (in.readBoolean()) { + buildInQuery = in.readEnum(BuildInQuery.class); + } else { + buildInQuery = null; + } + threshold = in.readOptionalFloat(); + if (in.readBoolean()) { + relationToThreshold = in.readEnum(RelationalOperation.class); + } else { + relationToThreshold = null; + } + if (in.readBoolean()) { + filterQuery = in.readNamedWriteable(QueryBuilder.class); + } else { + filterQuery = null; + } + if (in.readBoolean()) { + subaggregations = in.readList(Subaggregation::new); + } else { + subaggregations = null; + } + forecastFrom = in.readOptionalInstant(); + } + + public SearchTopForecastResultRequest( + String forecasterId, + String taskId, + boolean runOnce, + Integer size, + List splitBy, + FilterBy filterBy, + BuildInQuery buildInQuery, + Float threshold, + RelationalOperation relationToThreshold, + QueryBuilder filterQuery, + List subaggregations, + Instant forecastFrom + ) { + super(); + this.forecasterId = forecasterId; + this.taskId = taskId; + this.runOnce = runOnce; + this.size = size; + this.splitBy = splitBy; + this.filterBy = filterBy; + this.buildInQuery = buildInQuery; + this.threshold = threshold; + this.relationToThreshold = relationToThreshold; + this.filterQuery = filterQuery; + this.subaggregations = subaggregations; + this.forecastFrom = forecastFrom; + } + + public String getTaskId() { + return taskId; + } + + public boolean isRunOnce() { + return runOnce; + } + + public Integer getSize() { + return size; + } + + public String getForecasterId() { + return forecasterId; + } + + public List getSplitBy() { + return splitBy; + } + + public FilterBy getFilterBy() { + return filterBy; + } + + public BuildInQuery 
getBuildInQuery() { + return buildInQuery; + } + + public Float getThreshold() { + return threshold; + } + + public QueryBuilder getFilterQuery() { + return filterQuery; + } + + public List getSubaggregations() { + return subaggregations; + } + + public Instant getForecastFrom() { + return forecastFrom; + } + + public RelationalOperation getRelationToThreshold() { + return relationToThreshold; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public void setSize(Integer size) { + this.size = size; + } + + public void setForecasterId(String forecasterId) { + this.forecasterId = forecasterId; + } + + public void setRunOnce(boolean runOnce) { + this.runOnce = runOnce; + } + + public void setSplitBy(List splitBy) { + this.splitBy = splitBy; + } + + public void setFilterBy(FilterBy filterBy) { + this.filterBy = filterBy; + } + + public void setBuildInQuery(BuildInQuery buildInQuery) { + this.buildInQuery = buildInQuery; + } + + public void setThreshold(Float threshold) { + this.threshold = threshold; + } + + public void setFilterQuery(QueryBuilder filterQuery) { + this.filterQuery = filterQuery; + } + + public void setSubaggregations(List subaggregations) { + this.subaggregations = subaggregations; + } + + public void setForecastFrom(Instant forecastFrom) { + this.forecastFrom = forecastFrom; + } + + public void setRelationToThreshold(RelationalOperation relationToThreshold) { + this.relationToThreshold = relationToThreshold; + } + + public static SearchTopForecastResultRequest parse(XContentParser parser, String forecasterId) throws IOException { + String taskId = null; + Integer size = null; + List splitBy = null; + FilterBy filterBy = null; + BuildInQuery buildInQuery = null; + Float threshold = null; + RelationalOperation relationToThreshold = null; + QueryBuilder filterQuery = null; + List subaggregations = new ArrayList<>(); + Instant forecastFrom = null; + boolean runOnce = false; + + // "forecasterId" and "historical" params come from the original API path, not in the request body + // and therefore don't need to be parsed + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case SIZE_FIELD: + size = parser.intValue(); + break; + case SPLIT_BY_FIELD: + splitBy = Arrays.asList(parser.text().split(",")); + break; + case FILTER_BY_FIELD: + filterBy = FilterBy.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case BUILD_IN_QUERY_FIELD: + buildInQuery = BuildInQuery.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case THRESHOLD_FIELD: + threshold = parser.floatValue(); + break; + case RELATION_TO_THRESHOLD_FIELD: + relationToThreshold = RelationalOperation.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case FILTER_QUERY_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + try { + filterQuery = parseInnerQueryBuilder(parser); + } catch (ParsingException | XContentParseException e) { + throw new ValidationException( + "Custom query error in data filter: " + e.getMessage(), + ValidationIssueType.FILTER_QUERY, + ValidationAspect.FORECASTER + ); + } catch (IllegalArgumentException e) { + if (!e.getMessage().contains("empty clause")) { + throw e; + } + } + break; + case SUBAGGREGATIONS_FIELD: + try { + 
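+ // Expected shape assumed from Subaggregation.parse (not shown in this PR), e.g.:
+ //   "subaggregations": [ { "aggregation_query": { "max_value": { "max": { "field": "forecast_value" } } },
+ //     "order": "DESC" } ]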
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + subaggregations.add(Subaggregation.parse(parser)); + } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { + throw new ValidationException( + "Custom query error: " + e.getMessage(), + ValidationIssueType.SUBAGGREGATION, + ValidationAspect.FORECASTER + ); + } + throw e; + } + break; + case FORECAST_FROM_FIELD: + forecastFrom = ParseUtils.toInstant(parser); + break; + case RUN_ONCE_FIELD: + runOnce = parser.booleanValue(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new SearchTopForecastResultRequest( + forecasterId, + taskId, + runOnce, + size, + splitBy, + filterBy, + buildInQuery, + threshold, + relationToThreshold, + filterQuery, + subaggregations, + forecastFrom + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // "forecasterId" and "historical" params come from the original API path, not in the request body + // and therefore don't need to be in the generated json + builder + .startObject() + .field(TASK_ID_FIELD, taskId) + .field(SPLIT_BY_FIELD, String.join(",", splitBy)) + .field(FILTER_BY_FIELD, filterBy.name()) + .field(RUN_ONCE_FIELD, runOnce); + + if (size != null) { + builder.field(SIZE_FIELD, size); + } + if (buildInQuery != null) { + builder.field(BUILD_IN_QUERY_FIELD, buildInQuery); + } + if (threshold != null) { + builder.field(THRESHOLD_FIELD, threshold); + } + if (relationToThreshold != null) { + builder.field(RELATION_TO_THRESHOLD_FIELD, relationToThreshold); + } + if (filterQuery != null) { + builder.field(FILTER_QUERY_FIELD, filterQuery); + } + if (subaggregations != null) { + builder.field(SUBAGGREGATIONS_FIELD, subaggregations.toArray()); + } + if (forecastFrom != null) { + builder.field(FORECAST_FROM_FIELD, forecastFrom.toString()); + } + + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(forecasterId); + out.writeOptionalString(taskId); + out.writeBoolean(runOnce); + out.writeOptionalInt(size); + out.writeOptionalStringCollection(splitBy); + if (filterBy == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(filterBy); + } + if (buildInQuery == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(buildInQuery); + } + out.writeOptionalFloat(threshold); + if (relationToThreshold == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(relationToThreshold); + } + if (filterQuery == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeNamedWriteable(filterQuery); + } + if (subaggregations == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeList(subaggregations); + } + out.writeOptionalInstant(forecastFrom); + } + + @Override + public ActionRequestValidationException validate() { + if (forecasterId == null) { + return addValidationError("Cannot find forecasterId", null); + } + if (filterBy == null) { + return addValidationError("Must set filter_by", null); + } + if (forecastFrom == null) { + return addValidationError("Must set forecast_from with epoch of milliseconds", null); + } + if (!((filterBy == FilterBy.BUILD_IN_QUERY) == (buildInQuery != null))) { + throw new IllegalArgumentException( + "If 'filter_by' is set 
to BUILD_IN_QUERY, a 'build_in_query' type must be provided. Otherwise, 'build_in_query' should not be given." + ); + } + + if (filterBy == FilterBy.BUILD_IN_QUERY + && buildInQuery == BuildInQuery.DISTANCE_TO_THRESHOLD_VALUE + && (threshold == null || relationToThreshold == null)) { + return addValidationError( + String + .format(Locale.ROOT, "Must set threshold and relation_to_threshold, but got %s and %s", threshold, relationToThreshold), + null + ); + } + if (filterBy == FilterBy.CUSTOM_QUERY && (subaggregations == null || subaggregations.isEmpty())) { + return addValidationError("Must set subaggregations", null); + } + if (!runOnce && !Strings.isNullOrEmpty(taskId)) { + return addValidationError("task id must not be set when run_once is false", null); + } + return null; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultResponse.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultResponse.java new file mode 100644 index 000000000..ac4e5ed56 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultResponse.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResultBucket; + +/** + * Response for getting the top forecast results for HC forecasters + */ +public class SearchTopForecastResultResponse extends ActionResponse implements ToXContentObject { + public static final String BUCKETS_FIELD = "buckets"; + + private List<ForecastResultBucket> forecastResultBuckets; + + public SearchTopForecastResultResponse(StreamInput in) throws IOException { + super(in); + forecastResultBuckets = in.readList(ForecastResultBucket::new); + } + + public SearchTopForecastResultResponse(List<ForecastResultBucket> forecastResultBuckets) { + this.forecastResultBuckets = forecastResultBuckets; + } + + public List<ForecastResultBucket> getForecastResultBuckets() { + return forecastResultBuckets; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(forecastResultBuckets); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // no need to show bucket with empty value + return builder.startObject().field(BUCKETS_FIELD, forecastResultBuckets).endObject(); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultTransportAction.java new file mode 100644 index 000000000..069914f86 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SearchTopForecastResultTransportAction.java @@ -0,0 +1,605 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors.
See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.indices.ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.routing.Preference; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.model.FilterBy; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastResultBucket; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.model.Order; +import org.opensearch.forecast.model.Subaggregation; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.QueryUtil; +import org.opensearch.transport.TransportService; + +/** + * Transport action to fetch top forecast results for HC forecaster. 
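+ * Flow: look up the forecaster via GetForecasterAction, validate split_by/size/task id against
+ * its configuration, build a filtered terms aggregation over the result indices, and delegate
+ * the search to ForecastSearchHandler so existing user/backend-role security filtering is reused.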
+ */ +public class SearchTopForecastResultTransportAction extends + HandledTransportAction<SearchTopForecastResultRequest, SearchTopForecastResultResponse> { + private static final Logger logger = LogManager.getLogger(SearchTopForecastResultTransportAction.class); + private ForecastSearchHandler searchHandler; + private static final String defaultIndex = ALL_FORECAST_RESULTS_INDEX_PATTERN; + + // Number of buckets to return per page + private static final int DEFAULT_SIZE = 5; + private static final int MAX_SIZE = 50; + + protected static final String AGG_NAME_TERM = "term_agg"; + + private final Client client; + private NamedXContentRegistry xContent; + + @Inject + public SearchTopForecastResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ForecastSearchHandler searchHandler, + Client client, + NamedXContentRegistry xContent + ) { + super(SearchTopForecastResultAction.NAME, transportService, actionFilters, SearchTopForecastResultRequest::new); + this.searchHandler = searchHandler; + this.client = client; + this.xContent = xContent; + } + + @Override + protected void doExecute(Task task, SearchTopForecastResultRequest request, ActionListener<SearchTopForecastResultResponse> listener) { + GetConfigRequest getForecasterRequest = new GetConfigRequest( + request.getForecasterId(), + // The default version value used in + // org.opensearch.rest.action.RestActions.parseVersion() + -3L, + false, + true, + "", + "", + false, + null + ); + + client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, ActionListener.wrap(getForecasterResponse -> { + // Make sure forecaster exists + if (getForecasterResponse.getForecaster() == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "No forecaster found with ID %s", request.getForecasterId())); + } + + Forecaster forecaster = getForecasterResponse.getForecaster(); + // Make sure forecaster is HC + List<String> categoryFields = forecaster.getCategoryFields(); + if (categoryFields == null || categoryFields.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "No category fields found for forecaster ID %s", request.getForecasterId()) + ); + } + + // Validating the category fields. Setting the list to be all category fields, + // unless otherwise specified + if (request.getSplitBy() == null || request.getSplitBy().isEmpty()) { + request.setSplitBy(categoryFields); + } else { + for (String categoryField : request.getSplitBy()) { + if (!categoryFields.contains(categoryField)) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Category field %s doesn't exist for forecaster ID %s", + categoryField, + request.getForecasterId() + ) + ); + } + } + } + + // Validating run once tasks if runOnce is true. Setting the task id to the + // latest run once task's ID, unless otherwise specified + if (request.isRunOnce() && Strings.isNullOrEmpty(request.getTaskId())) { + ForecastTask runOnceTask = getForecasterResponse.getRunOnceTask(); + if (runOnceTask == null) { + throw new ResourceNotFoundException( + String.format(Locale.ROOT, "No latest run once tasks found for forecaster ID %s", request.getForecasterId()) + ); + } + request.setTaskId(runOnceTask.getTaskId()); + } + + // Validating the size.
If nothing passed use default + if (request.getSize() == null) { + request.setSize(DEFAULT_SIZE); + } else if (request.getSize() > MAX_SIZE) { + throw new IllegalArgumentException("Size cannot exceed " + MAX_SIZE); + } else if (request.getSize() <= 0) { + throw new IllegalArgumentException("Size must be a positive integer"); + } + + // Generating the search request which will contain the generated query + SearchRequest searchRequest = generateQuery(request, forecaster); + + // Adding search over any custom result indices + if (!Strings.isNullOrEmpty(forecaster.getCustomResultIndex())) { + searchRequest.indices(forecaster.getCustomResultIndex()); + } + // Utilizing the existing search() from SearchHandler to handle security + // permissions. Both user role + // and backend role filtering is handled in there, and any error will be + // propagated up and + // returned as a failure in this Listener. + // This same method is used for security handling for the search results action. + // Since this action + // is doing fundamentally the same thing, we can reuse the security logic here. + searchHandler.search(searchRequest, onSearchResponse(request, categoryFields, forecaster, listener)); + }, exception -> { + logger.error("Failed to get top forecast results", exception); + listener.onFailure(exception); + })); + + } + + private ActionListener onSearchResponse( + SearchTopForecastResultRequest request, + List categoryFields, + Forecaster forecaster, + ActionListener listener + ) { + return ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // empty result (e.g., cannot find forecasts within [forecast from, forecast from + horizon * interval] range). + listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<>())); + return; + } + + Aggregation aggResults = aggs.get(AGG_NAME_TERM); + if (aggResults == null) { + // empty result + listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<>())); + return; + } + + List buckets = ((MultiBucketsAggregation) aggResults).getBuckets(); + if (buckets == null || buckets.size() == 0) { + // empty result + listener + .onFailure( + new ResourceNotFoundException( + "No forecast value found. forecast_from timestamp or other parameters might be incorrect." 
+ ) + ); + return; + } + + final GroupedActionListener<ForecastResultBucket> groupListener = new GroupedActionListener<>(ActionListener.wrap(r -> { + // Restore the original bucket order: sort by getBucketIndex() in ascending order + // and convert the collection to a List + List<ForecastResultBucket> sortedList = r + .stream() + .sorted((a, b) -> Integer.compare(a.getBucketIndex(), b.getBucketIndex())) + .collect(Collectors.toList()); + listener.onResponse(new SearchTopForecastResultResponse(new ArrayList<>(sortedList))); + }, exception -> { + logger.warn("Failed to find valid aggregation result", exception); + listener + .onFailure(new OpenSearchStatusException("Failed to find valid aggregation result", RestStatus.INTERNAL_SERVER_ERROR)); + }), buckets.size()); + + for (int i = 0; i < buckets.size(); i++) { + MultiBucketsAggregation.Bucket bucket = buckets.get(i); + createForecastResultBucket(bucket, i, request, categoryFields, forecaster, groupListener); + } + }, e -> listener.onFailure(e)); + } + + public void createForecastResultBucket( + MultiBucketsAggregation.Bucket bucket, + int bucketIndex, + SearchTopForecastResultRequest request, + List<String> categoryFields, + Forecaster forecaster, + ActionListener<ForecastResultBucket> listener + ) { + Map<String, Double> aggregationsMap = new HashMap<>(); + for (Aggregation aggregation : bucket.getAggregations()) { + if (!(aggregation instanceof NumericMetricsAggregation.SingleValue)) { + listener + .onFailure( + new IllegalArgumentException( + String.format(Locale.ROOT, "A single value aggregation is required; received [%s]", aggregation) + ) + ); + return; + } + NumericMetricsAggregation.SingleValue singleValueAggregation = (NumericMetricsAggregation.SingleValue) aggregation; + aggregationsMap.put(aggregation.getName(), singleValueAggregation.value()); + } + if (bucket instanceof Terms.Bucket) { + // our terms key is string + convertToCategoricalFieldValuePair( + (String) bucket.getKey(), + bucketIndex, + (int) bucket.getDocCount(), + aggregationsMap, + request, + categoryFields, + forecaster, + listener + ); + } else { + listener + .onFailure( + new IllegalArgumentException(String.format(Locale.ROOT, "We only use terms aggregation in top, but got %s", bucket)) + ); + } + } + + private void convertToCategoricalFieldValuePair( + String keyInSearchResponse, + int bucketIndex, + int docCount, + Map<String, Double> aggregations, + SearchTopForecastResultRequest request, + List<String> categoryFields, + Forecaster forecaster, + ActionListener<ForecastResultBucket> listener + ) { + List<String> splitBy = request.getSplitBy(); + Map<String, Object> keys = new HashMap<>(); + // TODO: we only support two categorical fields. Expand to support more categorical fields + if (splitBy == null || splitBy.size() == categoryFields.size()) { + // use all categorical fields in splitBy. Convert entity id to concrete attributes.
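+ // The terms key here is the opaque entity id, so we look up one matching result document
+ // and parse its entity attributes to recover the readable {category field: value} pairs.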
+ findMatchingCategoricalFieldValuePair(keyInSearchResponse, docCount, aggregations, bucketIndex, forecaster, listener); + } else { + keys.put(splitBy.get(0), keyInSearchResponse); + listener.onResponse(new ForecastResultBucket(keys, docCount, aggregations, bucketIndex)); + } + } + + private void findMatchingCategoricalFieldValuePair( + String entityId, + int docCount, + Map aggregations, + int bucketIndex, + Forecaster forecaster, + ActionListener listener + ) { + TermQueryBuilder entityIdFilter = QueryBuilders.termQuery(CommonName.ENTITY_ID_FIELD, entityId); + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(entityIdFilter); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1); + + String resultIndex = Strings.isNullOrEmpty(forecaster.getCustomResultIndex()) ? defaultIndex : forecaster.getCustomResultIndex(); + SearchRequest searchRequest = new SearchRequest() + .indices(resultIndex) + .source(searchSourceBuilder) + .preference(Preference.LOCAL.toString()); + + String failure = String.format(Locale.ROOT, "Cannot find a result matching entity id %s", entityId); + + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + try { + SearchHit[] hits = searchResponse.getHits().getHits(); + if (hits.length == 0) { + listener.onFailure(new IllegalArgumentException(failure)); + return; + } + SearchHit searchHit = hits[0]; + try (XContentParser parser = createXContentParserFromRegistry(xContent, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Optional entity = ForecastResult.parse(parser).getEntity(); + if (entity.isEmpty()) { + listener.onFailure(new IllegalArgumentException(failure)); + return; + } + + listener + .onResponse( + new ForecastResultBucket(convertMap(entity.get().getAttributes()), docCount, aggregations, bucketIndex) + ); + } catch (Exception e) { + listener.onFailure(new IllegalArgumentException(failure, e)); + } + } catch (Exception e) { + listener.onFailure(new IllegalArgumentException(failure, e)); + } + }, e -> listener.onFailure(new IllegalArgumentException(failure, e))); + + searchHandler.search(searchRequest, searchResponseListener); + } + + private Map convertMap(Map stringMap) { + // Create a new Map and copy the entries + Map objectMap = new HashMap<>(); + for (Map.Entry entry : stringMap.entrySet()) { + objectMap.put(entry.getKey(), entry.getValue()); + } + return objectMap; + } + + /** + * Generates the entire search request to pass to the search handler + * + * @param request the request containing the all of the user-specified + * parameters needed to generate the request + * @param forecaster Forecaster config + * @return the SearchRequest to pass to the SearchHandler + */ + private SearchRequest generateQuery(SearchTopForecastResultRequest request, Forecaster forecaster) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + QueryBuilder rangeQuery = generateDateFilter(request, forecaster); + boolQueryBuilder = boolQueryBuilder.filter(rangeQuery); + + // we only look for documents containing forecasts + boolQueryBuilder.filter(new ExistsQueryBuilder(ForecastResult.VALUE_FIELD)); + + FilterBy filterBy = request.getFilterBy(); + switch (filterBy) { + case CUSTOM_QUERY: + if (request.getFilterQuery() != null) { + boolQueryBuilder = boolQueryBuilder.filter(request.getFilterQuery()); + } + break; + case BUILD_IN_QUERY: + QueryBuilder buildInSubFilter = 
+
+    /**
+     * Generates the entire search request to pass to the search handler.
+     *
+     * @param request the request containing all of the user-specified parameters needed to
+     *        generate the request
+     * @param forecaster Forecaster config
+     * @return the SearchRequest to pass to the SearchHandler
+     */
+    private SearchRequest generateQuery(SearchTopForecastResultRequest request, Forecaster forecaster) {
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        QueryBuilder rangeQuery = generateDateFilter(request, forecaster);
+        boolQueryBuilder = boolQueryBuilder.filter(rangeQuery);
+
+        // we only look for documents containing forecasts
+        boolQueryBuilder.filter(new ExistsQueryBuilder(ForecastResult.VALUE_FIELD));
+
+        FilterBy filterBy = request.getFilterBy();
+        switch (filterBy) {
+            case CUSTOM_QUERY:
+                if (request.getFilterQuery() != null) {
+                    boolQueryBuilder = boolQueryBuilder.filter(request.getFilterQuery());
+                }
+                break;
+            case BUILD_IN_QUERY:
+                QueryBuilder buildInSubFilter = generateBuildInSubFilter(request, forecaster);
+                if (buildInSubFilter != null) {
+                    boolQueryBuilder = boolQueryBuilder.filter(buildInSubFilter);
+                }
+                break;
+            default:
+                throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected filter by %s", request.getFilterBy()));
+        }
+
+        boolQueryBuilder = generateTaskIdFilter(request, boolQueryBuilder);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).trackTotalHits(false).size(0);
+
+        AggregationBuilder termsAgg = generateTermsAggregation(request, forecaster);
+        if (termsAgg != null) {
+            searchSourceBuilder = searchSourceBuilder.aggregation(termsAgg);
+        }
+        return new SearchRequest().indices(defaultIndex).source(searchSourceBuilder);
+    }
+
+    private QueryBuilder generateBuildInSubFilter(SearchTopForecastResultRequest request, Forecaster forecaster) {
+        BuildInQuery buildInQuery = request.getBuildInQuery();
+        switch (buildInQuery) {
+            case MIN_CONFIDENCE_INTERVAL_WIDTH:
+            case MAX_CONFIDENCE_INTERVAL_WIDTH:
+                // Include only documents whose horizon_index equals the configured horizon (i.e., the "latest" forecast).
+                return QueryBuilders.termQuery(ForecastResult.HORIZON_INDEX_FIELD, forecaster.getHorizon());
+            case DISTANCE_TO_THRESHOLD_VALUE:
+                RangeQueryBuilder res = QueryBuilders.rangeQuery(ForecastResult.VALUE_FIELD);
+                Float threshold = request.getThreshold();
+                switch (request.getRelationToThreshold()) {
+                    case GREATER_THAN:
+                        res = res.gt(threshold);
+                        break;
+                    case GREATER_THAN_OR_EQUAL_TO:
+                        res = res.gte(threshold);
+                        break;
+                    case LESS_THAN:
+                        res = res.lt(threshold);
+                        break;
+                    case LESS_THAN_OR_EQUAL_TO:
+                        res = res.lte(threshold);
+                        break;
+                }
+                return res;
+            default:
+                // no sub-filter needed in cases like MIN_VALUE_WITHIN_THE_HORIZON
+                return null;
+        }
+    }
+
+    /**
+     * Adds the date filter, which is needed regardless of the filter-by type.
+     *
+     * @param request top forecaster request
+     * @param forecaster Forecaster config
+     * @return range filter on the data end time
+     */
+    private RangeQueryBuilder generateDateFilter(SearchTopForecastResultRequest request, Forecaster forecaster) {
+        // "forecast from" maps to the data end time of a forecast run, so match exactly one
+        // interval: [forecastFrom, forecastFrom + interval)
+        long startInclusive = request.getForecastFrom().toEpochMilli();
+        long endExclusive = startInclusive + forecaster.getIntervalInMilliseconds();
+        return QueryBuilders.rangeQuery(CommonName.DATA_END_TIME_FIELD).gte(startInclusive).lt(endExclusive);
+    }
+
+    /**
+     * Generates the query with appropriate filters on the results indices. If
+     * fetching real-time results: must_not filter on task_id (because real-time
+     * results don't have a 'task_id' field associated with them in the document).
+     * If fetching historical results: term filter on the task_id.
+ * + * @param request the request containing the necessary fields to generate the query + * @param query Bool query to generate + * @return input bool query with added id related filter + */ + private BoolQueryBuilder generateTaskIdFilter(SearchTopForecastResultRequest request, BoolQueryBuilder query) { + if (!Strings.isNullOrEmpty(request.getTaskId())) { + query.filter(QueryBuilders.termQuery(CommonName.TASK_ID_FIELD, request.getTaskId())); + } else { + TermQueryBuilder forecasterIdFilter = QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, request.getForecasterId()); + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); + query.filter(forecasterIdFilter).mustNot(taskIdExistsFilter); + } + return query; + } + + /** + * Generates aggregation. Creating a list of sources based on the + * set of category fields, and sorting on the returned result buckets + * + * @param request the request containing the necessary fields to generate the + * aggregation + * @return the generated aggregation as an AggregationBuilder + */ + private TermsAggregationBuilder generateTermsAggregation(SearchTopForecastResultRequest request, Forecaster forecaster) { + // TODO: use multi_terms or composite when multiple categorical fields are required. + // Right now, since we only support two categorical fields, we either use terms + // aggregation for one categorical field or terms aggregation on entity_id for + // all categorical fields. + TermsAggregationBuilder termsAgg = AggregationBuilders.terms(AGG_NAME_TERM).size(request.getSize()); + + if (request.getSplitBy().size() == forecaster.getCategoryFields().size()) { + termsAgg = termsAgg.field(CommonName.ENTITY_ID_FIELD); + } else if (request.getSplitBy().size() == 1) { + termsAgg = termsAgg.script(QueryUtil.getScriptForCategoryField(request.getSplitBy().get(0))); + } + + List orders = new ArrayList<>(); + + FilterBy filterBy = request.getFilterBy(); + switch (filterBy) { + case BUILD_IN_QUERY: + Pair aggregationOrderPair = generateBuildInSubAggregation(request); + termsAgg.subAggregation(aggregationOrderPair.getLeft()); + orders.add(aggregationOrderPair.getRight()); + break; + case CUSTOM_QUERY: + // if customers defined customized aggregation + for (Subaggregation subaggregation : request.getSubaggregations()) { + AggregatorFactories.Builder internalAgg; + try { + internalAgg = ParseUtils.parseAggregators(subaggregation.getAggregation().toString(), xContent, null); + AggregationBuilder aggregation = internalAgg.getAggregatorFactories().iterator().next(); + termsAgg.subAggregation(aggregation); + orders.add(BucketOrder.aggregation(aggregation.getName(), subaggregation.getOrder() == Order.ASC ? 
true : false)); + } catch (IOException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unexpected IOException when parsing %s", subaggregation), + e + ); + } + } + break; + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected filter by %s", filterBy)); + } + + if (orders.isEmpty()) { + throw new IllegalArgumentException("Cannot have empty order list"); + } + + termsAgg.order(orders); + + return termsAgg; + } + + private Pair generateBuildInSubAggregation(SearchTopForecastResultRequest request) { + String aggregationName = null; + AggregationBuilder aggregation = null; + BucketOrder order = null; + BuildInQuery buildInQuery = request.getBuildInQuery(); + switch (buildInQuery) { + case MIN_CONFIDENCE_INTERVAL_WIDTH: + aggregationName = BuildInQuery.MIN_CONFIDENCE_INTERVAL_WIDTH.name(); + aggregation = AggregationBuilders.min(aggregationName).field(ForecastResult.INTERVAL_WIDTH_FIELD); + order = BucketOrder.aggregation(aggregationName, true); + return Pair.of(aggregation, order); + case MAX_CONFIDENCE_INTERVAL_WIDTH: + aggregationName = BuildInQuery.MAX_CONFIDENCE_INTERVAL_WIDTH.name(); + aggregation = AggregationBuilders.max(aggregationName).field(ForecastResult.INTERVAL_WIDTH_FIELD); + order = BucketOrder.aggregation(aggregationName, false); + return Pair.of(aggregation, order); + case MIN_VALUE_WITHIN_THE_HORIZON: + aggregationName = BuildInQuery.MIN_VALUE_WITHIN_THE_HORIZON.name(); + aggregation = AggregationBuilders.min(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, true); + return Pair.of(aggregation, order); + case MAX_VALUE_WITHIN_THE_HORIZON: + aggregationName = BuildInQuery.MAX_VALUE_WITHIN_THE_HORIZON.name(); + aggregation = AggregationBuilders.max(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, false); + return Pair.of(aggregation, order); + case DISTANCE_TO_THRESHOLD_VALUE: + RelationalOperation relationToThreshold = request.getRelationToThreshold(); + switch (relationToThreshold) { + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL_TO: + aggregationName = BuildInQuery.DISTANCE_TO_THRESHOLD_VALUE.name(); + aggregation = AggregationBuilders.max(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, false); + return Pair.of(aggregation, order); + case LESS_THAN: + case LESS_THAN_OR_EQUAL_TO: + aggregationName = BuildInQuery.DISTANCE_TO_THRESHOLD_VALUE.name(); + aggregation = AggregationBuilders.min(aggregationName).field(ForecastResult.VALUE_FIELD); + order = BucketOrder.aggregation(aggregationName, true); + return Pair.of(aggregation, order); + default: + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unexpected relation to threshold %s", relationToThreshold) + ); + } + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Unexpected build in query type %s", buildInQuery)); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/StatsForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/StatsForecasterAction.java new file mode 100644 index 000000000..951850a67 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StatsForecasterAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.transport;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.forecast.constant.ForecastCommonValue;
+import org.opensearch.timeseries.transport.StatsTimeSeriesResponse;
+
+public class StatsForecasterAction extends ActionType {
+    // External action used for public-facing REST APIs.
+    public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/stats";
+    public static final StatsForecasterAction INSTANCE = new StatsForecasterAction();
+
+    private StatsForecasterAction() {
+        super(NAME, StatsTimeSeriesResponse::new);
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/StatsForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/StatsForecasterTransportAction.java
new file mode 100644
index 000000000..d2c6d3619
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/StatsForecasterTransportAction.java
@@ -0,0 +1,129 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.forecast.transport;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.forecast.stats.ForecastStats;
+import org.opensearch.index.query.BoolQueryBuilder;
+import org.opensearch.index.query.ExistsQueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.search.aggregations.AggregationBuilders;
+import org.opensearch.search.aggregations.bucket.SingleBucketAggregation;
+import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.model.Config;
+import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.transport.BaseStatsTransportAction;
+import org.opensearch.timeseries.transport.StatsRequest;
+import org.opensearch.timeseries.transport.StatsResponse;
+import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener;
+import org.opensearch.transport.TransportService;
+
+public class StatsForecasterTransportAction extends BaseStatsTransportAction {
+    public final Logger logger = LogManager.getLogger(StatsForecasterTransportAction.class);
+    private final String WITH_CATEGORY_FIELD = "with_category_field";
+    private final String WITHOUT_CATEGORY_FIELD = "without_category_field";
+
+    @Inject
+    public StatsForecasterTransportAction(
+        TransportService transportService,
+        ActionFilters actionFilters,
+        Client client,
+        ForecastStats stats,
+        ClusterService clusterService
+    ) {
+        super(transportService, actionFilters, client, stats, clusterService, StatsForecasterAction.NAME);
+    }
+
+    /**
+     * Make an async request to count forecasters in the config index if necessary
+     * and, onResponse, gather the cluster statistics.
+     *
+ * @param client Client + * @param listener MultiResponsesDelegateActionListener to be used once both requests complete + * @param statsRequest Request containing stats to be retrieved + */ + @Override + public void getClusterStats(Client client, MultiResponsesDelegateActionListener listener, StatsRequest statsRequest) { + StatsResponse adStatsResponse = new StatsResponse(); + if ((statsRequest.getStatsToBeRetrieved().contains(StatNames.FORECASTER_COUNT.getName()) + || statsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName()) + || statsRequest.getStatsToBeRetrieved().contains(StatNames.HC_FORECASTER_COUNT.getName())) + && clusterService.state().getRoutingTable().hasIndex(CommonName.CONFIG_INDEX)) { + + // Create the query + ExistsQueryBuilder existsQuery = QueryBuilders.existsQuery(Config.CATEGORY_FIELD); + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery().mustNot(existsQuery); + + FilterAggregationBuilder withFieldAgg = AggregationBuilders.filter(WITH_CATEGORY_FIELD, existsQuery); + FilterAggregationBuilder withoutFieldAgg = AggregationBuilders.filter(WITHOUT_CATEGORY_FIELD, boolQuery); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + searchSourceBuilder.size(0); + searchSourceBuilder.aggregation(withFieldAgg); + searchSourceBuilder.aggregation(withoutFieldAgg); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX); + searchRequest.source(searchSourceBuilder); + + // Execute the query + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + // Parse the response + SingleBucketAggregation withField = (SingleBucketAggregation) searchResponse.getAggregations().get(WITH_CATEGORY_FIELD); + SingleBucketAggregation withoutField = (SingleBucketAggregation) searchResponse + .getAggregations() + .get(WITHOUT_CATEGORY_FIELD); + if (statsRequest.getStatsToBeRetrieved().contains(StatNames.FORECASTER_COUNT.getName())) { + stats.getStat(StatNames.FORECASTER_COUNT.getName()).setValue(withField.getDocCount() + withoutField.getDocCount()); + } + if (statsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName())) { + stats.getStat(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName()).setValue(withoutField.getDocCount()); + } + if (statsRequest.getStatsToBeRetrieved().contains(StatNames.HC_FORECASTER_COUNT.getName())) { + stats.getStat(StatNames.HC_FORECASTER_COUNT.getName()).setValue(withField.getDocCount()); + } + adStatsResponse.setClusterStats(getClusterStatsMap(statsRequest)); + listener.onResponse(adStatsResponse); + }, e -> listener.onFailure(e))); + } else { + adStatsResponse.setClusterStats(getClusterStatsMap(statsRequest)); + listener.onResponse(adStatsResponse); + } + } + + /** + * Make async request to get the forecasting statistics from each node and, onResponse, set the + * StatsNodesResponse field of StatsResponse + * + * @param client Client + * @param listener MultiResponsesDelegateActionListener to be used once both requests complete + * @param statsRequest Request containing stats to be retrieved + */ + @Override + public void getNodeStats(Client client, MultiResponsesDelegateActionListener listener, StatsRequest statsRequest) { + client.execute(ForecastStatsNodesAction.INSTANCE, statsRequest, ActionListener.wrap(adStatsResponse -> { + StatsResponse restStatsResponse = new StatsResponse(); + restStatsResponse.setStatsNodesResponse(adStatsResponse); + listener.onResponse(restStatsResponse); 
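// Illustrative sketch, not part of the patch: the cluster-stats branch above derives three
// counters from a single search by attaching two filter sub-aggregations that split the
// config index into configs with and without a category field. Condensed shape of that
// request; the index and field names here are placeholders, not the plugin's constants:
import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;

class CategoryFieldCountSketch {
    SearchRequest countByCategoryField(String configIndex, String categoryField) {
        SearchSourceBuilder source = new SearchSourceBuilder()
            .query(QueryBuilders.matchAllQuery())
            .size(0) // counts come from the aggregations, not from hits
            .aggregation(AggregationBuilders.filter("with_category_field", QueryBuilders.existsQuery(categoryField)))
            .aggregation(
                AggregationBuilders
                    .filter("without_category_field", QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(categoryField)))
            );
        // total = with + without; HC count = with; single-stream count = without
        return new SearchRequest(configIndex).source(source);
    }
}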
+ }, listener::onFailure)); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java new file mode 100644 index 000000000..9b38db2eb --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.StopConfigResponse; + +public class StopForecasterAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "forecaster/stop"; + public static final StopForecasterAction INSTANCE = new StopForecasterAction(); + + private StopForecasterAction() { + super(NAME, StopConfigResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java new file mode 100644 index 000000000..a3a35e3f2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.forecast.transport;
+
+import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_STOP_FORECASTER;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.opensearch.ExceptionsHelper;
+import org.opensearch.action.ActionRequest;
+import org.opensearch.action.FailedNodeException;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.tasks.Task;
+import org.opensearch.timeseries.common.exception.InternalFailure;
+import org.opensearch.timeseries.transport.DeleteModelRequest;
+import org.opensearch.timeseries.transport.StopConfigRequest;
+import org.opensearch.timeseries.transport.StopConfigResponse;
+import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
+import org.opensearch.transport.TransportService;
+
+public class StopForecasterTransportAction extends HandledTransportAction {
+
+    private static final Logger LOG = LogManager.getLogger(StopForecasterTransportAction.class);
+
+    private final Client client;
+    private final DiscoveryNodeFilterer nodeFilter;
+
+    @Inject
+    public StopForecasterTransportAction(
+        TransportService transportService,
+        DiscoveryNodeFilterer nodeFilter,
+        ActionFilters actionFilters,
+        Client client
+    ) {
+        super(StopForecasterAction.NAME, transportService, actionFilters, StopConfigRequest::new);
+        this.client = client;
+        this.nodeFilter = nodeFilter;
+    }
+
+    @Override
+    protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) {
+        StopConfigRequest request = StopConfigRequest.fromActionRequest(actionRequest);
+        String configId = request.getConfigID();
+        try {
+            DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes();
+            DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(configId, dataNodes);
+            client.execute(DeleteForecastModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> {
+                if (response.hasFailures()) {
+                    LOG.warn("Cannot delete all models of forecaster {}", configId);
+                    for (FailedNodeException failedNodeException : response.failures()) {
+                        LOG.warn("Failed to delete models on a node", failedNodeException);
+                    }
+                    // if customers are using an updated forecaster and we haven't deleted the old
+                    // checkpoints, they would run into trouble
+                    listener.onResponse(new StopConfigResponse(false));
+                } else {
+                    LOG.info("Deleted models of forecaster {}", configId);
+                    listener.onResponse(new StopConfigResponse(true));
+                }
+            }, exception -> {
+                LOG.error(new ParameterizedMessage("Deletion of forecaster [{}] has exception.", configId), exception);
+                listener.onResponse(new StopConfigResponse(false));
+            }));
+        } catch (Exception e) {
+            LOG.error(FAIL_TO_STOP_FORECASTER + " " + configId, e);
+            Throwable cause = ExceptionsHelper.unwrapCause(e);
+            listener.onFailure(new InternalFailure(configId, FAIL_TO_STOP_FORECASTER, cause));
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamAction.java b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamAction.java
new file mode 100644
index 000000000..bbcee1b72
--- /dev/null
+++ b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamAction.java
@@ -0,0 +1,19 @@
+/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.SuggestConfigParamResponse; + +public class SuggestForecasterParamAction extends ActionType { + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/suggest"; + public static final SuggestForecasterParamAction INSTANCE = new SuggestForecasterParamAction(); + + private SuggestForecasterParamAction() { + super(NAME, SuggestConfigParamResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamTransportAction.java b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamTransportAction.java new file mode 100644 index 000000000..d96fe1de5 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/SuggestForecasterParamTransportAction.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.transport.BaseSuggestConfigParamTransportAction; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class SuggestForecasterParamTransportAction extends BaseSuggestConfigParamTransportAction { + public static final Logger logger = LogManager.getLogger(SuggestForecasterParamTransportAction.class); + + @Inject + public SuggestForecasterParamTransportAction( + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ForecastIndexManagement anomalyDetectionIndices, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao + ) { + super( + SuggestForecasterParamAction.NAME, + client, + clientUtil, + clusterService, + settings, + actionFilters, + transportService, + FORECAST_FILTER_BY_BACKEND_ROLES, + AnalysisType.FORECAST, + searchFeatureDao + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ValidateForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterAction.java new file mode 100644 index 000000000..26cf17666 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterAction.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ValidateConfigResponse; + +public class ValidateForecasterAction extends ActionType { + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/validate"; + 
public static final ValidateForecasterAction INSTANCE = new ValidateForecasterAction(); + + private ValidateForecasterAction() { + super(NAME, ValidateConfigResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java new file mode 100644 index 000000000..01b38ffa6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ValidateForecasterTransportAction.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.rest.handler.ValidateForecasterActionHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.rest.handler.Processor; +import org.opensearch.timeseries.transport.BaseValidateConfigTransportAction; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ValidateForecasterTransportAction extends BaseValidateConfigTransportAction { + public static final Logger logger = LogManager.getLogger(ValidateForecasterTransportAction.class); + + @Inject + public ValidateForecasterTransportAction( + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Settings settings, + ForecastIndexManagement anomalyDetectionIndices, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao + ) { + super( + ValidateForecasterAction.NAME, + client, + clientUtil, + clusterService, + xContentRegistry, + settings, + anomalyDetectionIndices, + actionFilters, + transportService, + searchFeatureDao, + FORECAST_FILTER_BY_BACKEND_ROLES + ); + } + + @Override + protected Processor createProcessor(Config forecaster, ValidateConfigRequest request, User user) { + return new ValidateForecasterActionHandler( + clusterService, + client, + clientUtil, + indexManagement, + forecaster, + request.getRequestTimeout(), + request.getMaxSingleEntityAnomalyDetectors(), + request.getMaxMultiEntityAnomalyDetectors(), + request.getMaxAnomalyFeatures(), + request.getMaxCategoricalFields(), + RestRequest.Method.POST, + xContentRegistry, + user, + searchFeatureDao, + request.getValidationType(), + clock, + settings + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java new file mode 
100644 index 000000000..1f94257e3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.transport.ForecastResultBulkAction; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +public class ForecastIndexMemoryPressureAwareResultHandler extends + IndexMemoryPressureAwareResultHandler { + private static final Logger LOG = LogManager.getLogger(ForecastIndexMemoryPressureAwareResultHandler.class); + + @Inject + public ForecastIndexMemoryPressureAwareResultHandler(Client client, ForecastIndexManagement anomalyDetectionIndices) { + super(client, anomalyDetectionIndices); + } + + @Override + public void bulk(ForecastResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ForecastResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/handler/ForecastSearchHandler.java b/src/main/java/org/opensearch/forecast/transport/handler/ForecastSearchHandler.java new file mode 100644 index 000000000..61979f534 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/handler/ForecastSearchHandler.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport.handler; + +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.timeseries.transport.handler.SearchHandler; + +/** + * Handle general search request, check user role and return search response. 
+ */ +public class ForecastSearchHandler extends SearchHandler { + + public ForecastSearchHandler(Settings settings, ClusterService clusterService, Client client) { + super(settings, clusterService, client, ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES); + } +} diff --git a/src/main/java/org/opensearch/ad/AbstractProfileRunner.java b/src/main/java/org/opensearch/timeseries/AbstractProfileRunner.java similarity index 92% rename from src/main/java/org/opensearch/ad/AbstractProfileRunner.java rename to src/main/java/org/opensearch/timeseries/AbstractProfileRunner.java index e402a4da1..79345db34 100644 --- a/src/main/java/org/opensearch/ad/AbstractProfileRunner.java +++ b/src/main/java/org/opensearch/timeseries/AbstractProfileRunner.java @@ -9,11 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.util.Locale; -import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.timeseries.model.InitProgressProfile; public abstract class AbstractProfileRunner { protected long requiredSamples; diff --git a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java b/src/main/java/org/opensearch/timeseries/AbstractSearchAction.java similarity index 73% rename from src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java rename to src/main/java/org/opensearch/timeseries/AbstractSearchAction.java index 1d0611cf7..43681e78f 100644 --- a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java +++ b/src/main/java/org/opensearch/timeseries/AbstractSearchAction.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.rest; +package org.opensearch.timeseries; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.timeseries.util.RestHandlerUtils.getSourceContext; @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; @@ -24,8 +25,6 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.client.node.NodeClient; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; @@ -36,6 +35,7 @@ import org.opensearch.rest.RestResponse; import org.opensearch.rest.action.RestResponseListener; import org.opensearch.search.builder.SearchSourceBuilder; +import org.owasp.encoder.Encode; /** * Abstract class to handle search request. 
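// Illustrative sketch, not part of the patch: the hunk below wraps request parsing and
// HTML-encodes IllegalArgumentException messages with the OWASP Java Encoder before
// rethrowing, so user-controlled text echoed back in an error cannot inject markup. The
// idiom in isolation:
import org.owasp.encoder.Encode;

class SafeErrorMessageSketch {
    static IllegalArgumentException sanitized(IllegalArgumentException e) {
        // Encode.forHtml escapes <, >, &, and quotes, making the message safe for an HTML context
        return new IllegalArgumentException(Encode.forHtml(e.getMessage()));
    }
}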
@@ -47,6 +47,8 @@ public abstract class AbstractSearchAction extends B protected final List urlPaths; protected final List> deprecatedPaths; protected final ActionType actionType; + protected final Supplier adEnabledSupplier; + protected final String disabledMsg; private final Logger logger = LogManager.getLogger(AbstractSearchAction.class); @@ -55,29 +57,38 @@ public AbstractSearchAction( List> deprecatedPaths, String index, Class clazz, - ActionType actionType + ActionType actionType, + Supplier adEnabledSupplier, + String disabledMsg ) { this.index = index; this.clazz = clazz; this.urlPaths = urlPaths; this.deprecatedPaths = deprecatedPaths; this.actionType = actionType; + this.adEnabledSupplier = adEnabledSupplier; + this.disabledMsg = disabledMsg; } @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - if (!ADEnabledSetting.isADEnabled()) { - throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + if (!adEnabledSupplier.get()) { + throw new IllegalStateException(disabledMsg); } - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); - // order of response will be re-arranged everytime we use `_source`, we sometimes do this - // even if user doesn't give this field as we exclude ui_metadata if request isn't from OSD - // ref-link: https://github.com/elastic/elasticsearch/issues/17639 - searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); - searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); - SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index); - return channel -> client.execute(actionType, searchRequest, search(channel)); + try { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + // order of response will be re-arranged everytime we use `_source`, we sometimes do this + // even if user doesn't give this field as we exclude ui_metadata if request isn't from OSD + // ref-link: https://github.com/elastic/elasticsearch/issues/17639 + searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index); + return channel -> client.execute(actionType, searchRequest, search(channel)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException(Encode.forHtml(e.getMessage())); + } + } protected void onFailure(RestChannel channel, Exception e) { diff --git a/src/main/java/org/opensearch/ad/DetectorModelSize.java b/src/main/java/org/opensearch/timeseries/AnalysisModelSize.java similarity index 74% rename from src/main/java/org/opensearch/ad/DetectorModelSize.java rename to src/main/java/org/opensearch/timeseries/AnalysisModelSize.java index 52e4660e6..5e70c456c 100644 --- a/src/main/java/org/opensearch/ad/DetectorModelSize.java +++ b/src/main/java/org/opensearch/timeseries/AnalysisModelSize.java @@ -9,16 +9,16 @@ * GitHub history for details. 
*/ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.util.Map; -public interface DetectorModelSize { +public interface AnalysisModelSize { /** * Gets all of a detector's model sizes hosted on a node * - * @param detectorId Detector Id + * @param id Analysis Id * @return a map of model id to its memory size */ - Map getModelSize(String detectorId); + Map getModelSize(String id); } diff --git a/src/main/java/org/opensearch/timeseries/AnalysisType.java b/src/main/java/org/opensearch/timeseries/AnalysisType.java index 7d7cc805e..f0f4e2025 100644 --- a/src/main/java/org/opensearch/timeseries/AnalysisType.java +++ b/src/main/java/org/opensearch/timeseries/AnalysisType.java @@ -7,5 +7,13 @@ public enum AnalysisType { AD, - FORECAST + FORECAST; + + public boolean isForecast() { + return this == FORECAST; + } + + public boolean isAD() { + return this == AD; + } } diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/timeseries/EntityProfileRunner.java similarity index 76% rename from src/main/java/org/opensearch/ad/EntityProfileRunner.java rename to src/main/java/org/opensearch/timeseries/EntityProfileRunner.java index 3fc04fe96..43dbe3cbc 100644 --- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java +++ b/src/main/java/org/opensearch/timeseries/EntityProfileRunner.java @@ -9,10 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; @@ -20,27 +21,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionType; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.model.EntityProfile; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.EntityState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.settings.ADNumericSetting; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileRequest; -import org.opensearch.ad.transport.EntityProfileResponse; import org.opensearch.client.Client; import org.opensearch.cluster.routing.Preference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -50,52 +41,82 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; 
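// Illustrative sketch, not part of the patch: the new AnalysisType helpers above replace
// scattered enum comparisons at call sites. Hypothetical usage; the alias names below are
// invented for illustration:
import org.opensearch.timeseries.AnalysisType;

class AnalysisTypeUsageSketch {
    String pickResultIndexAlias(AnalysisType type) {
        // type.isForecast() reads better than type == AnalysisType.FORECAST at branch points
        return type.isForecast() ? "forecast-results-alias" : "anomaly-results-alias";
    }
}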
+import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.EntityState; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.EntityProfileRequest; +import org.opensearch.timeseries.transport.EntityProfileResponse; import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.SecurityClientUtil; -public class EntityProfileRunner extends AbstractProfileRunner { +public class EntityProfileRunner> extends AbstractProfileRunner { private final Logger logger = LogManager.getLogger(EntityProfileRunner.class); - static final String NOT_HC_DETECTOR_ERR_MSG = "This is not a high cardinality detector"; + public static final String NOT_HC_DETECTOR_ERR_MSG = "This is not a high cardinality detector"; static final String EMPTY_ENTITY_ATTRIBUTES = "Empty entity attributes"; static final String NO_ENTITY = "Cannot find entity"; private Client client; private SecurityClientUtil clientUtil; private NamedXContentRegistry xContentRegistry; - - public EntityProfileRunner(Client client, SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, long requiredSamples) { + private BiCheckedFunction configParser; + private int maxCategoryFields; + private AnalysisType analysisType; + private EntityProfileActionType entityProfileAction; + private String resultIndexAlias; + private String configIdField; + + public EntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples, + BiCheckedFunction configParser, + int maxCategoryFields, + AnalysisType analysisType, + EntityProfileActionType entityProfileAction, + String resultIndexAlias, + String configIdField + ) { super(requiredSamples); this.client = client; this.clientUtil = clientUtil; this.xContentRegistry = xContentRegistry; + this.configParser = configParser; + this.maxCategoryFields = maxCategoryFields; + this.analysisType = analysisType; + this.entityProfileAction = entityProfileAction; + this.resultIndexAlias = resultIndexAlias; + this.configIdField = configIdField; } /** * Get profile info of specific entity. 
* - * @param detectorId detector identifier + * @param configId config identifier * @param entityValue entity value * @param profilesToCollect profiles to collect * @param listener action listener to handle exception and process entity profile response */ public void profile( - String detectorId, + String configId, Entity entityValue, Set profilesToCollect, ActionListener listener ) { if (profilesToCollect == null || profilesToCollect.size() == 0) { - listener.onFailure(new IllegalArgumentException(ADCommonMessages.EMPTY_PROFILES_COLLECT)); + listener.onFailure(new IllegalArgumentException(CommonMessages.EMPTY_PROFILES_COLLECT)); return; } - GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); + GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, configId); client.get(getDetectorRequest, ActionListener.wrap(getResponse -> { if (getResponse != null && getResponse.isExists()) { @@ -105,21 +126,20 @@ public void profile( .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId); - List categoryFields = detector.getCategoryFields(); - int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); + Config config = configParser.apply(parser, configId); + List categoryFields = config.getCategoryFields(); if (categoryFields == null || categoryFields.size() == 0) { listener.onFailure(new IllegalArgumentException(NOT_HC_DETECTOR_ERR_MSG)); } else if (categoryFields.size() > maxCategoryFields) { listener.onFailure(new IllegalArgumentException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields))); } else { - validateEntity(entityValue, categoryFields, detectorId, profilesToCollect, detector, listener); + validateEntity(entityValue, categoryFields, configId, profilesToCollect, config, listener); } } catch (Exception t) { listener.onFailure(t); } } else { - listener.onFailure(new IllegalArgumentException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + detectorId)); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); } }, listener::onFailure)); } @@ -144,14 +164,18 @@ private void validateEntity( List categoryFields, String detectorId, Set profilesToCollect, - AnomalyDetector detector, + Config config, ActionListener listener ) { Map attributes = entity.getAttributes(); - if (attributes == null || attributes.size() != categoryFields.size()) { + if (attributes == null) { listener.onFailure(new IllegalArgumentException(EMPTY_ENTITY_ATTRIBUTES)); return; } + if (attributes.size() != categoryFields.size()) { + listener.onFailure(new IllegalArgumentException(NO_ENTITY)); + return; + } for (String field : categoryFields) { if (false == attributes.containsKey(field)) { listener.onFailure(new IllegalArgumentException("Cannot find " + field)); @@ -159,15 +183,15 @@ private void validateEntity( } } - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(config.getFilterQuery()); - for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + for (TermQueryBuilder term : entity.getTermQueryForCustomerIndex()) { internalFilterQuery.filter(term); } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1); - SearchRequest 
searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder) + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0]), searchSourceBuilder) .preference(Preference.LOCAL.toString()); final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { try { @@ -175,7 +199,7 @@ private void validateEntity( listener.onFailure(new IllegalArgumentException(NO_ENTITY)); return; } - prepareEntityProfile(listener, detectorId, entity, profilesToCollect, detector, categoryFields.get(0)); + prepareEntityProfile(listener, detectorId, entity, profilesToCollect, config, categoryFields.get(0)); } catch (Exception e) { listener.onFailure(new IllegalArgumentException(NO_ENTITY)); return; @@ -187,9 +211,9 @@ private void validateEntity( .asyncRequestWithInjectedSecurity( searchRequest, client::search, - detector.getId(), + config.getId(), client, - AnalysisType.AD, + analysisType, searchResponseListener ); @@ -200,16 +224,16 @@ private void prepareEntityProfile( String detectorId, Entity entityValue, Set profilesToCollect, - AnomalyDetector detector, + Config config, String categoryField ) { EntityProfileRequest request = new EntityProfileRequest(detectorId, entityValue, profilesToCollect); client .execute( - EntityProfileAction.INSTANCE, + entityProfileAction, request, - ActionListener.wrap(r -> getJob(detectorId, entityValue, profilesToCollect, detector, r, listener), listener::onFailure) + ActionListener.wrap(r -> getJob(detectorId, entityValue, profilesToCollect, config, r, listener), listener::onFailure) ); } @@ -217,7 +241,7 @@ private void getJob( String detectorId, Entity entityValue, Set profilesToCollect, - AnomalyDetector detector, + Config config, EntityProfileResponse entityProfileResponse, ActionListener listener ) { @@ -247,7 +271,7 @@ private void getJob( new MultiResponsesDelegateActionListener( listener, totalResponsesToWait, - ADCommonMessages.FAIL_FETCH_ERR_MSG + entityValue + " of detector " + detectorId, + CommonMessages.FAIL_FETCH_ERR_MSG + entityValue + " of detector " + detectorId, false ); @@ -267,7 +291,7 @@ private void getJob( detectorId, entityValue, profilesToCollect, - detector, + config, job, delegateListener ); @@ -279,7 +303,7 @@ private void getJob( detectorId, enabledTimeMs, entityValue, - detector.getCustomResultIndex() + config.getCustomResultIndex() ); EntityProfile.Builder builder = new EntityProfile.Builder(); @@ -310,7 +334,7 @@ private void getJob( })); } } catch (Exception e) { - logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG, e); listener.onFailure(e); } } else { @@ -321,7 +345,7 @@ private void getJob( logger.info(exception.getMessage()); sendUnknownState(profilesToCollect, entityValue, true, listener); } else { - logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG + detectorId, exception); + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG + detectorId, exception); listener.onFailure(exception); } })); @@ -332,7 +356,7 @@ private void profileStateRelated( String detectorId, Entity entityValue, Set profilesToCollect, - AnomalyDetector detector, + Config config, Job job, MultiResponsesDelegateActionListener delegateListener ) { @@ -343,7 +367,7 @@ private void profileStateRelated( } else if (totalUpdates >= requiredSamples) { sendRunningState(profilesToCollect, entityValue, delegateListener); } else { - sendInitState(profilesToCollect, entityValue, detector, totalUpdates, delegateListener); + 
sendInitState(profilesToCollect, entityValue, config, totalUpdates, delegateListener); } } @@ -390,7 +414,7 @@ private void sendRunningState( private void sendInitState( Set profilesToCollect, Entity entityValue, - AnomalyDetector detector, + Config config, long updates, MultiResponsesDelegateActionListener delegateListener ) { @@ -399,64 +423,21 @@ private void sendInitState( builder.state(EntityState.INIT); } if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)) { - long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes(); + long intervalMins = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMinutes(); InitProgressProfile initProgress = computeInitProgressProfile(updates, intervalMins); builder.initProgress(initProgress); } delegateListener.onResponse(builder.build()); } - private SearchRequest createLastSampleTimeRequest(String detectorId, long enabledTime, Entity entity, String resultIndex) { + private SearchRequest createLastSampleTimeRequest(String configId, long enabledTime, Entity entity, String resultIndex) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - String path = "entity"; - String entityName = path + ".name"; - String entityValue = path + ".value"; - - for (Map.Entry attribute : entity.getAttributes().entrySet()) { - /* - * each attribute pair corresponds to a nested query like - "nested": { - "query": { - "bool": { - "filter": [ - { - "term": { - "entity.name": { - "value": "turkey4", - "boost": 1 - } - } - }, - { - "term": { - "entity.value": { - "value": "Turkey", - "boost": 1 - } - } - } - ] - } - }, - "path": "entity", - "ignore_unmapped": false, - "score_mode": "none", - "boost": 1 - } - },*/ - BoolQueryBuilder nestedBoolQueryBuilder = new BoolQueryBuilder(); - - TermQueryBuilder entityNameFilterQuery = QueryBuilders.termQuery(entityName, attribute.getKey()); - nestedBoolQueryBuilder.filter(entityNameFilterQuery); - TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValue, attribute.getValue()); - nestedBoolQueryBuilder.filter(entityValueFilterQuery); - - NestedQueryBuilder nestedNameQueryBuilder = new NestedQueryBuilder(path, nestedBoolQueryBuilder, ScoreMode.None); + for (NestedQueryBuilder nestedNameQueryBuilder : entity.getTermQueryForResultIndex()) { boolQueryBuilder.filter(nestedNameQueryBuilder); } - boolQueryBuilder.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + boolQueryBuilder.filter(QueryBuilders.termQuery(configIdField, configId)); boolQueryBuilder.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); @@ -466,7 +447,7 @@ private SearchRequest createLastSampleTimeRequest(String detectorId, long enable .trackTotalHits(false) .size(0); - SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + SearchRequest request = new SearchRequest(resultIndexAlias); request.source(source); if (resultIndex != null) { request.indices(resultIndex); diff --git a/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java new file mode 100644 index 000000000..3b19684a7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java @@ -0,0 +1,366 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import java.time.Instant; +import java.util.HashSet; +import java.util.List; +import 
java.util.Locale; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionType; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHits; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ExceptionUtil; + +public abstract class ExecuteResultResponseRecorder & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType> { + + private static final Logger log = LogManager.getLogger(ExecuteResultResponseRecorder.class); + + protected IndexManagementType indexManagement; + private ResultBulkIndexingHandler resultHandler; + protected TaskManagerType taskManager; + private DiscoveryNodeFilterer nodeFilter; + private ThreadPool threadPool; + private String threadPoolName; + private Client client; + private NodeStateManager nodeStateManager; + private TaskCacheManager taskCacheManager; + private int rcfMinSamples; + protected IndexType resultIndex; + private AnalysisType analysisType; + private ProfileActionType profileAction; + + public ExecuteResultResponseRecorder( + IndexManagementType indexManagement, + ResultBulkIndexingHandler resultHandler, + TaskManagerType taskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + String threadPoolName, + Client client, + NodeStateManager nodeStateManager, + TaskCacheManager taskCacheManager, + int rcfMinSamples, + IndexType resultIndex, + AnalysisType analysisType, + ProfileActionType profileAction + ) { + this.indexManagement = indexManagement; + this.resultHandler = resultHandler; + this.taskManager = taskManager; + this.nodeFilter = 
nodeFilter; + this.threadPool = threadPool; + this.threadPoolName = threadPoolName; + this.client = client; + this.nodeStateManager = nodeStateManager; + this.taskCacheManager = taskCacheManager; + this.rcfMinSamples = rcfMinSamples; + this.resultIndex = resultIndex; + this.analysisType = analysisType; + this.profileAction = profileAction; + } + + public void indexResult( + Instant detectionStartTime, + Instant executionStartTime, + ResultResponse response, + Config config + ) { + String configId = config.getId(); + try { + + if (!response.shouldSave()) { + updateRealtimeTask(response, configId); + return; + } + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) config.getWindowDelay(); + Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = config.getUser(); + + if (response.getError() != null) { + log.info("Result action ran successfully for {} with error {}", configId, response.getError()); + } + + List analysisResults = response + .toIndexableResults( + configId, + dataStartTime, + dataEndTime, + executionStartTime, + Instant.now(), + indexManagement.getSchemaVersion(resultIndex), + user, + response.getError() + ); + + String resultIndex = config.getCustomResultIndex(); + resultHandler + .bulk( + resultIndex, + analysisResults, + configId, + ActionListener + .wrap( + r -> {}, + exception -> log.error(String.format(Locale.ROOT, "Fail to bulk for %s", configId), exception) + ) + ); + updateRealtimeTask(response, configId); + } catch (EndRunException e) { + throw e; + } catch (Exception e) { + log.error("Failed to index result for " + configId, e); + } + } + + /** + * + * If the result action is handled asynchronously, the response won't contain the result. + * This function waits some time before fetching the update. + * One side effect is that, if the config is already deleted, the latest task will get deleted too. + * This delayed update can cause ResourceNotFoundException. + * + * @param response response returned from executing AnomalyResultAction + * @param configId config Id + */ + protected void delayedUpdate(ResultResponse response, String configId) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + Set profiles = new HashSet<>(); + profiles.add(ProfileName.INIT_PROGRESS); + ProfileRequest profileRequest = new ProfileRequest(configId, profiles, true, dataNodes); + Runnable profileHCInitProgress = () -> { + client.execute(profileAction, profileRequest, ActionListener.wrap(r -> { + log.debug("Update latest realtime task for config {}, total updates: {}", configId, r.getTotalUpdates()); + updateLatestRealtimeTask(configId, null, r.getTotalUpdates(), response.getConfigIntervalInMinutes(), response.getError()); + }, e -> { log.error("Failed to update latest realtime task for " + configId, e); })); + }; + if (!taskManager.isHCRealtimeTaskStartInitializing(configId)) { + // a real-time init progress of 0 may mean this is a newly started detector + // Delay the real-time cache update by one minute. If we are in init status, the delay may give the model training time to + // finish. We can then mark the detector as running immediately instead of waiting for the next interval.
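The branch that follows either defers the profile fetch by one minute or runs it inline. As a minimal standalone sketch of that defer-or-run pattern (the method and parameter names are illustrative and not part of this PR; ThreadPool.schedule and TimeValue are the same calls the hunk itself uses):

    import java.util.concurrent.TimeUnit;
    import org.opensearch.common.unit.TimeValue;
    import org.opensearch.threadpool.ThreadPool;

    // Sketch only: defer work while a model is still initializing, otherwise run it inline.
    void runOrDefer(ThreadPool threadPool, String threadPoolName, boolean stillInitializing, Runnable work) {
        if (stillInitializing) {
            // Give model training a minute to finish before polling init progress again.
            threadPool.schedule(work, new TimeValue(60, TimeUnit.SECONDS), threadPoolName);
        } else {
            work.run();
        }
    }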
+ threadPool.schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), threadPoolName); + } else { + profileHCInitProgress.run(); + } + } + + protected void updateLatestRealtimeTask( + String configId, + String taskState, + Long rcfTotalUpdates, + Long configIntervalInMinutes, + String error + ) { + // Don't need info as this will be printed repeatedly in each interval + ActionListener listener = ActionListener.wrap(r -> { + if (r != null) { + log.debug("Updated latest realtime task successfully for config {}, taskState: {}", configId, taskState); + } + }, e -> { + if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CommonMessages.CAN_NOT_FIND_LATEST_TASK)) { + // Clear realtime task cache, will recreate task in next run, check ADResultProcessor. + log.error("Can't find latest realtime task of config " + configId); + taskManager.removeRealtimeTaskCache(configId); + } else { + log.error("Failed to update latest realtime task for config " + configId, e); + } + }); + + // rcfTotalUpdates is null when we save exception messages + if (!taskCacheManager.hasQueriedResultIndex(configId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { + // confirm the total updates number since it is possible that we have already had results after job enabling time + // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. + confirmTotalRCFUpdatesFound( + configId, + taskState, + rcfTotalUpdates, + configIntervalInMinutes, + error, + ActionListener + .wrap( + r -> taskManager + .updateLatestRealtimeTaskOnCoordinatingNode(configId, taskState, r, configIntervalInMinutes, error, listener), + e -> { + log.error("Fail to confirm rcf update", e); + taskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + configId, + taskState, + rcfTotalUpdates, + configIntervalInMinutes, + error, + listener + ); + } + ) + ); + } else { + taskManager + .updateLatestRealtimeTaskOnCoordinatingNode(configId, taskState, rcfTotalUpdates, configIntervalInMinutes, error, listener); + } + } + + /** + * The function is not only indexing the result with the exception, but also updating the task state after + * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector. + * + * @param executeStartTime execution start time + * @param executeEndTime execution end time + * @param errorMessage Error message to record + * @param taskState task state (e.g., stopped) + * @param config config accessor + */ + public void indexResultException( + Instant executeStartTime, + Instant executeEndTime, + String errorMessage, + String taskState, + Config config + ) { + String configId = config.getId(); + try { + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) config.getWindowDelay(); + Instant dataStartTime = executeStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executeEndTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = config.getUser(); + + IndexableResultType resultToSave = createErrorResult(configId, dataStartTime, dataEndTime, executeEndTime, errorMessage, user); + String resultIndex = config.getCustomResultIndex(); + if (resultIndex != null && !indexManagement.doesIndexExist(resultIndex)) { + // Set result index as null, will write exception to default result index. 
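The comment above encodes a small fallback rule: write to the custom result index only when it exists, and route to the default result index (by passing null) otherwise. A condensed sketch of the same rule, reusing names from the surrounding method:

    // Sketch: a null index name routes the document to the default result index.
    String target = (resultIndex != null && indexManagement.doesIndexExist(resultIndex)) ? resultIndex : null;
    resultHandler.index(resultToSave, configId, target);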
+ resultHandler.index(resultToSave, configId, null); + } else { + resultHandler.index(resultToSave, configId, resultIndex); + } + + if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !config.isHighCardinality()) { + // single stream detector raises ResourceNotFoundException containing ADCommonMessages.NO_CHECKPOINT_ERR_MSG + // when there is no checkpoint. + // Delay the real-time cache update by one minute so we will have trained models by then and update the state + // document accordingly. + threadPool.schedule(() -> { + RCFPollingRequest request = new RCFPollingRequest(configId); + client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> { + long totalUpdates = rcfPollResponse.getTotalUpdates(); + // if there are updates, don't record failures + updateLatestRealtimeTask( + configId, + taskState, + totalUpdates, + config.getIntervalInMinutes(), + totalUpdates > 0 ? "" : errorMessage + ); + }, e -> { + log.error("Fail to execute RCFPollingAction", e); + updateLatestRealtimeTask(configId, taskState, null, null, errorMessage); + })); + }, new TimeValue(60, TimeUnit.SECONDS), threadPoolName); + } else { + updateLatestRealtimeTask(configId, taskState, null, null, errorMessage); + } + + } catch (Exception e) { + log.error("Failed to index anomaly result for " + configId, e); + } + } + + private void confirmTotalRCFUpdatesFound( + String configId, + String taskState, + Long rcfTotalUpdates, + Long configIntervalInMinutes, + String error, + ActionListener listener + ) { + nodeStateManager.getConfig(configId, analysisType, ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(configId, "fail to get config")); + return; + } + nodeStateManager.getJob(configId, ActionListener.wrap(jobOptional -> { + if (!jobOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(configId, "fail to get job")); + return; + } + + ProfileUtil + .confirmRealtimeInitStatus( + configOptional.get(), + jobOptional.get().getEnabledTime().toEpochMilli(), + client, + analysisType, + ActionListener.wrap(searchResponse -> { + ActionListener.completeWith(listener, () -> { + SearchHits hits = searchResponse.getHits(); + Long correctedTotalUpdates = rcfTotalUpdates; + if (hits.getTotalHits().value > 0L) { + // correct the number if we have already had results after job enabling time + // so that the detector won't stay in the initializing state + correctedTotalUpdates = Long.valueOf(rcfMinSamples); + } + taskCacheManager.markResultIndexQueried(configId); + return correctedTotalUpdates; + }); + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + taskCacheManager.markResultIndexQueried(configId); + listener.onResponse(0L); + } else { + listener.onFailure(exception); + } + }) + ); + }, e -> listener.onFailure(new TimeSeriesException(configId, "fail to get job")))); + }, e -> listener.onFailure(new TimeSeriesException(configId, "fail to get config")))); + } + + protected abstract IndexableResultType createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user + ); + + // protected abstract void updateRealtimeTask(ResultResponseType response, String configId); + protected abstract void updateRealtimeTask(ResultResponse response, String configId); +} diff --git a/src/main/java/org/opensearch/timeseries/JobProcessor.java b/src/main/java/org/opensearch/timeseries/JobProcessor.java
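Before JobProcessor.java begins, one note on the recorder above: the correction step in confirmTotalRCFUpdatesFound reduces to a single rule, namely that any real-time result found after the job's enabled time means init progress should read 100%, so at least rcfMinSamples updates are reported. A hedged sketch of that rule (the helper below is hypothetical, not in the PR):

    // Sketch: bump the reported RCF update count once real-time results are observed.
    long correctedUpdates(long reportedUpdates, long hitsAfterEnabledTime, int rcfMinSamples) {
        // Results after the job enabling time imply the model has effectively finished
        // initializing, so report at least rcfMinSamples; otherwise keep the count.
        return hitsAfterEnabledTime > 0 ? Math.max(reportedUpdates, rcfMinSamples) : reportedUpdates;
    }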
new file mode 100644 index 000000000..8ce5e861b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/JobProcessor.java @@ -0,0 +1,583 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionType; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.InjectSecurity; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.util.SecurityUtil; + +import com.google.common.base.Throwables; + +/** + * JobScheduler will call job runner to get time series analysis result periodically + */ +public abstract class JobProcessor & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder, IndexJobActionHandlerType extends IndexJobActionHandler> { + + private static final Logger log = LogManager.getLogger(JobProcessor.class); + + private Settings settings; + private int maxRetryForEndRunException; + private Client client; + private ThreadPool threadPool; + private ConcurrentHashMap endRunExceptionCount; + protected IndexManagementType indexManagement; + private TaskManagerType taskManager; + private NodeStateManager nodeStateManager; + 
private ExecuteResultResponseRecorderType recorder; + private AnalysisType analysisType; + private String threadPoolName; + private ActionType> resultAction; + private IndexJobActionHandlerType indexJobActionHandler; + + protected JobProcessor( + AnalysisType analysisType, + String threadPoolName, + ActionType> resultAction + ) { + // Singleton class; subclasses expose getInstance instead of a public constructor + this.endRunExceptionCount = new ConcurrentHashMap<>(); + this.analysisType = analysisType; + this.threadPoolName = threadPoolName; + this.resultAction = resultAction; + } + + public void setClient(Client client) { + this.client = client; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + protected void registerSettings(Settings settings, Setting maxRetryForEndRunExceptionSetting) { + this.settings = settings; + this.maxRetryForEndRunException = maxRetryForEndRunExceptionSetting.get(settings); + } + + public void setTaskManager(TaskManagerType adTaskManager) { + this.taskManager = adTaskManager; + } + + public void setIndexManagement(IndexManagementType anomalyDetectionIndices) { + this.indexManagement = anomalyDetectionIndices; + } + + public void setNodeStateManager(NodeStateManager nodeStateManager) { + this.nodeStateManager = nodeStateManager; + } + + public void setExecuteResultResponseRecorder(ExecuteResultResponseRecorderType recorder) { + this.recorder = recorder; + } + + public void setIndexJobActionHandler(IndexJobActionHandlerType indexJobActionHandler) { + this.indexJobActionHandler = indexJobActionHandler; + } + + public void process(Job jobParameter, JobExecutionContext context) { + String configId = jobParameter.getName(); + + log.info("Start to run {} job {}", analysisType, configId); + + taskManager.refreshRealtimeJobRunTime(configId); + + Instant executionEndTime = Instant.now(); + IntervalSchedule schedule = (IntervalSchedule) jobParameter.getSchedule(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + + final LockService lockService = context.getLockService(); + + Runnable runnable = () -> { + try { + nodeStateManager.getConfig(configId, analysisType, ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + log.error(new ParameterizedMessage("fail to get config [{}]", configId)); + return; + } + Config config = configOptional.get(); + + if (jobParameter.getLockDurationSeconds() != null) { + lockService + .acquireLock( + jobParameter, + context, + ActionListener + .wrap( + lock -> runJob( + jobParameter, + lockService, + lock, + executionStartTime, + executionEndTime, + recorder, + config + ), + exception -> { + indexResultException( + jobParameter, + lockService, + null, + executionStartTime, + executionEndTime, + exception, + false, + recorder, + config + ); + throw new IllegalStateException("Failed to acquire lock for job: " + configId); + } + ) + ); + } else { + log.warn("Can't get lock for job: " + configId); + } + + }, e -> log.error(new ParameterizedMessage("fail to get config [{}]", configId), e))); + } catch (Exception e) { + // the OS log won't show anything if an exception happens (maybe due to running on an ExecutorService), + // so we at least log the error. + log.error("Can't start job: " + configId, e); + throw e; + } + }; + + ExecutorService executor = threadPool.executor(threadPoolName); + executor.submit(runnable); + } + + /** + * Get the analysis result, and index the result or handle the exception if the execution failed.
+ * + * @param jobParameter scheduled job parameter + * @param lockService lock service + * @param lock lock to run job + * @param executionStartTime analysis start time + * @param executionEndTime analysis end time + * @param recorder utility to record job execution result + * @param detector associated detector accessor + */ + public void runJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + String configId = jobParameter.getName(); + if (lock == null) { + indexResultException( + jobParameter, + lockService, + lock, + executionStartTime, + executionEndTime, + "Can't run job due to null lock", + false, + recorder, + detector + ); + return; + } + indexManagement.update(); + + User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); + + String user = userInfo.getName(); + List roles = userInfo.getRoles(); + + validateResultIndexAndRunJob( + jobParameter, + lockService, + lock, + executionStartTime, + executionEndTime, + configId, + user, + roles, + recorder, + detector + ); + } + + protected abstract void validateResultIndexAndRunJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteResultResponseRecorderType recorder2, + Config detector + ); + + protected void runJob( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + String configId, + String user, + List roles, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + // using one thread in the write threadpool + try (InjectSecurity injectSecurity = new InjectSecurity(configId, settings, client.threadPool().getThreadContext())) { + // Injecting user role to verify if the user has permissions for our API. + injectSecurity.inject(user, roles); + + ResultRequest request = createResultRequest(configId, executionStartTime.toEpochMilli(), executionEndTime.toEpochMilli()); + client.execute(resultAction, request, ActionListener.wrap(response -> { + indexResult(jobParameter, lockService, lock, executionStartTime, executionEndTime, response, recorder, detector); + }, + exception -> { + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, exception, recorder, detector); + } + )); + } catch (Exception e) { + indexResultException(jobParameter, lockService, lock, executionStartTime, executionEndTime, e, true, recorder, detector); + log.error("Failed to execute AD job " + configId, e); + } + } + + /** + * Handle exception from anomaly result action. + * + * 1. If exception is {@link EndRunException} + * a). if isEndNow == true, stop job and store exception in result + * b). if isEndNow == false, record count of {@link EndRunException} for this + * analysis. If count of {@link EndRunException} exceeds upper limit, will + * stop job and store exception in result; otherwise, just + * store exception in result, not stop job for the config. + * + * 2. If exception is not {@link EndRunException}, decrease count of + * {@link EndRunException} for the config and index exception in + * result. If exception is {@link InternalFailure}, will not log exception + * stack trace as already logged in {@link JobProcessor}. + * + * TODO: Handle finer granularity exception such as some exception may be + * transient and retry in current job may succeed. 
Currently, we don't + * know which exception is transient and retryable in + * {@link JobProcessor}. So we don't add backoff retry + * now to avoid bring extra load to cluster, expecially the code start + * process is relatively heavy by sending out 24 queries, initializing + * models, and saving checkpoints. + * Sometimes missing anomaly and notification is not acceptable. For example, + * current detection interval is 1hour, and there should be anomaly in + * current interval, some transient exception may fail current AD job, + * so no anomaly found and user never know it. Then we start next AD job, + * maybe there is no anomaly in next 1hour, user will never know something + * wrong happened. In one word, this is some tradeoff between protecting + * our performance, user experience and what we can do currently. + * + * @param jobParameter scheduled job parameter + * @param lockService lock service + * @param lock lock to run job + * @param detectionStartTime detection start time + * @param executionStartTime detection end time + * @param exception exception + * @param recorder utility to record job execution result + * @param config associated config accessor + */ + public void handleException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + Exception exception, + ExecuteResultResponseRecorderType recorder, + Config config + ) { + String configId = jobParameter.getName(); + if (exception instanceof EndRunException) { + log.error("EndRunException happened when executing result action for " + configId, exception); + + if (((EndRunException) exception).isEndNow()) { + // Stop AD job if EndRunException shows we should end job now. + log.info("JobRunner will stop job due to EndRunException for {}", configId); + stopJobForEndRunException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + (EndRunException) exception, + recorder, + config + ); + } else { + endRunExceptionCount.compute(configId, (k, v) -> { + if (v == null) { + return 1; + } else { + return v + 1; + } + }); + log.info("EndRunException happened for {}", configId); + // if AD job failed consecutively due to EndRunException and failed times exceeds upper limit, will stop AD job + if (endRunExceptionCount.get(configId) > maxRetryForEndRunException) { + log + .info( + "JobRunner will stop job due to EndRunException retry exceeds upper limit {} for {}", + maxRetryForEndRunException, + configId + ); + stopJobForEndRunException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + (EndRunException) exception, + recorder, + config + ); + return; + } + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception.getMessage(), + true, + recorder, + config + ); + } + } else { + endRunExceptionCount.remove(configId); + if (exception instanceof InternalFailure) { + log.error("InternalFailure happened when executing result action for " + configId, exception); + } else { + log.error("Failed to execute result action for " + configId, exception); + } + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception, + true, + recorder, + config + ); + } + } + + private void stopJobForEndRunException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + EndRunException exception, + ExecuteResultResponseRecorderType recorder, 
+ Config config + ) { + String configId = jobParameter.getName(); + endRunExceptionCount.remove(configId); + String errorPrefix = exception.isEndNow() + ? "Stopped analysis: " + : "Stopped analysis as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; + String error = errorPrefix + exception.getMessage(); + + ExecutorFunction runAfter = () -> indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + error, + true, + TaskState.STOPPED.name(), + recorder, + config + ); + + ActionListener stopListener = ActionListener.wrap(jobResponse -> { + log.info("Job was disabled by JobRunner for " + configId); + runAfter.execute(); + }, exp -> { + log.error("JobRunner failed to update job as disabled for " + configId, exp); + runAfter.execute(); + }); + + // The transport service is null as we cannot access the transport service outside of a transport action. + // To reset a real-time job we don't need the transport service, and we have guarded against the null + // reference in the task manager. + indexJobActionHandler.stopJob(configId, null, stopListener); + } + + private void indexResult( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant executionStartTime, + Instant executionEndTime, + ResultResponse response, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + String detectorId = jobParameter.getName(); + endRunExceptionCount.remove(detectorId); + try { + recorder.indexResult(executionStartTime, executionEndTime, response, detector); + } catch (EndRunException e) { + handleException(jobParameter, lockService, lock, executionStartTime, executionEndTime, e, recorder, detector); + } catch (Exception e) { + log.error("Failed to index anomaly result for " + detectorId, e); + } finally { + releaseLock(jobParameter, lockService, lock); + } + + } + + private void indexResultException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + Exception exception, + boolean releaseLock, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + try { + String errorMessage = exception instanceof TimeSeriesException + ?
exception.getMessage() + : Throwables.getStackTraceAsString(exception); + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + recorder, + detector + ); + } catch (Exception e) { + log.error("Failed to index result for " + jobParameter.getName(), e); + } + } + + private void indexResultException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + boolean releaseLock, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + indexResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + null, + recorder, + detector + ); + } + + private void indexResultException( + Job jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + boolean releaseLock, + String taskState, + ExecuteResultResponseRecorderType recorder, + Config detector + ) { + try { + recorder.indexResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); + } finally { + if (releaseLock) { + releaseLock(jobParameter, lockService, lock); + } + } + } + + private void releaseLock(Job jobParameter, LockService lockService, LockModel lock) { + lockService + .release( + lock, + ActionListener + .wrap(released -> { log.info("Released lock for {} job {}", analysisType, jobParameter.getName()); }, exception -> { + log + .error( + new ParameterizedMessage("Failed to release lock for [{}] job [{}]", analysisType, jobParameter.getName()), + exception + ); + }) + ); + } + + protected abstract ResultRequest createResultRequest(String configID, long start, long end); +} diff --git a/src/main/java/org/opensearch/timeseries/JobRunner.java b/src/main/java/org/opensearch/timeseries/JobRunner.java new file mode 100644 index 000000000..68a50ee4f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/JobRunner.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.ad.ADJobProcessor; +import org.opensearch.forecast.ForecastJobProcessor; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.timeseries.model.Job; + +public class JobRunner implements ScheduledJobRunner { + private static JobRunner INSTANCE; + + public static JobRunner getJobRunnerInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (JobRunner.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new JobRunner(); + return INSTANCE; + } + } + + @Override + public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { + if (!(scheduledJobParameter instanceof Job)) { + throw new IllegalArgumentException( + "Job parameter is not instance of Job, type: " + scheduledJobParameter.getClass().getCanonicalName() + ); + } + Job jobParameter = (Job) scheduledJobParameter; + switch (jobParameter.getAnalysisType()) { + case AD: + ADJobProcessor.getInstance().process(jobParameter, context); + break; + case FORECAST: + ForecastJobProcessor.getInstance().process(jobParameter, context); + break; + default: + throw new IllegalArgumentException("Analysis type is not supported, type: 
: " + jobParameter.getAnalysisType()); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/MaintenanceState.java b/src/main/java/org/opensearch/timeseries/MaintenanceState.java index 07bbb9546..7fc55dc6d 100644 --- a/src/main/java/org/opensearch/timeseries/MaintenanceState.java +++ b/src/main/java/org/opensearch/timeseries/MaintenanceState.java @@ -22,11 +22,11 @@ public interface MaintenanceState { default void maintenance(Map stateToClean, Duration stateTtl) { stateToClean.entrySet().stream().forEach(entry -> { - K detectorId = entry.getKey(); + K configId = entry.getKey(); V state = entry.getValue(); if (state.expired(stateTtl)) { - stateToClean.remove(detectorId); + stateToClean.remove(configId); } }); diff --git a/src/main/java/org/opensearch/timeseries/NodeStateManager.java b/src/main/java/org/opensearch/timeseries/NodeStateManager.java index 799a1b6ca..37f3336f6 100644 --- a/src/main/java/org/opensearch/timeseries/NodeStateManager.java +++ b/src/main/java/org/opensearch/timeseries/NodeStateManager.java @@ -78,6 +78,8 @@ public class NodeStateManager implements MaintenanceState, CleanState, Exception * @param clock A UTC clock * @param stateTtl Max time to keep state in memory * @param clusterService Cluster service accessor + * @param maxRetryForUnresponsiveNodeSetting max retry number for unresponsive node + * @param backoffMinutesSetting back off minutes setting */ public NodeStateManager( Client client, @@ -206,9 +208,9 @@ public void getConfig( ) { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); Config config = null; - if (analysisType == AnalysisType.AD) { + if (analysisType.isAD()) { config = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); - } else if (analysisType == AnalysisType.FORECAST) { + } else if (analysisType.isForecast()) { config = Forecaster.parse(parser, response.getId(), response.getVersion()); } else { throw new UnsupportedOperationException("This method is not supported"); @@ -232,7 +234,7 @@ public void getConfig(String configID, AnalysisType context, ActionListener configParser = context == AnalysisType.AD + BiCheckedFunction configParser = context.isAD() ? 
AnomalyDetector::parse : Forecaster::parse; clientUtil.asyncRequest(request, client::get, onGetConfigResponse(configID, configParser, listener)); diff --git a/src/main/java/org/opensearch/timeseries/ProfileRunner.java b/src/main/java/org/opensearch/timeseries/ProfileRunner.java new file mode 100644 index 000000000..6e486c17b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileRunner.java @@ -0,0 +1,567 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.NOT_FOUND; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalCardinality; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; +import org.opensearch.timeseries.model.InitProgressProfile; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import 
org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public abstract class ProfileRunner & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskProfileType extends TaskProfile, TaskManagerType extends TaskManager, ConfigProfileType extends ConfigProfile, ProfileActionType extends ActionType, TaskProfileRunnerType extends TaskProfileRunner> + extends AbstractProfileRunner { + private final Logger logger = LogManager.getLogger(ProfileRunner.class); + protected Client client; + protected SecurityClientUtil clientUtil; + protected NamedXContentRegistry xContentRegistry; + protected DiscoveryNodeFilterer nodeFilter; + protected final TransportService transportService; + protected final TaskManagerType taskManager; + protected final int maxTotalEntitiesToTrack; + protected final AnalysisType analysisType; + protected final List realTimeTaskTypes; + protected final List batchConfigTaskTypes; + protected int maxCategoricalFields; + protected ProfileName taskProfile; + protected TaskProfileRunnerType taskProfileRunner; + protected ProfileActionType profileAction; + protected BiCheckedFunction configParser; + + public ProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + TaskManagerType taskManager, + AnalysisType analysisType, + List realTimeTaskTypes, + List batchConfigTaskTypes, + int maxCategoricalFields, + ProfileName taskProfile, + ProfileActionType profileAction, + BiCheckedFunction configParser, + TaskProfileRunnerType taskProfileRunner + ) { + super(requiredSamples); + this.client = client; + this.clientUtil = clientUtil; + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + if (requiredSamples <= 0) { + throw new IllegalArgumentException("required samples should be a positive number, but was " + requiredSamples); + } + this.transportService = transportService; + this.taskManager = taskManager; + this.maxTotalEntitiesToTrack = TimeSeriesSettings.MAX_TOTAL_ENTITIES_TO_TRACK; + this.analysisType = analysisType; + this.realTimeTaskTypes = realTimeTaskTypes; + this.batchConfigTaskTypes = batchConfigTaskTypes; + this.maxCategoricalFields = maxCategoricalFields; + this.taskProfile = taskProfile; + this.profileAction = profileAction; + this.configParser = configParser; + this.taskProfileRunner = taskProfileRunner; + } + + public void profile(String configId, ActionListener listener, Set profilesToCollect) { + if (profilesToCollect.isEmpty()) { + listener.onFailure(new IllegalArgumentException(CommonMessages.EMPTY_PROFILES_COLLECT)); + return; + } + calculateTotalResponsesToWait(configId, profilesToCollect, listener); + } + + private void calculateTotalResponsesToWait( + String configId, + Set profilesToCollect, + ActionListener listener + ) { + GetRequest getConfigRequest = new GetRequest(CommonName.CONFIG_INDEX, configId); + client.get(getConfigRequest, ActionListener.wrap(getConfigResponse -> { + if (getConfigResponse != null && getConfigResponse.isExists()) { + try ( + XContentParser xContentParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getConfigResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); + Config config = configParser.apply(xContentParser, configId); + prepareProfile(config, listener, 
profilesToCollect); + } catch (Exception e) { + logger.error(CommonMessages.FAIL_TO_PARSE_CONFIG_MSG + configId, e); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_PARSE_CONFIG_MSG + configId, BAD_REQUEST)); + } + } else { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, NOT_FOUND)); + } + }, exception -> { + logger.error(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, exception); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, NOT_FOUND)); + })); + } + + protected void prepareProfile(Config config, ActionListener listener, Set profilesToCollect) { + String configId = config.getId(); + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, configId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + + int totalResponsesToWait = 0; + if (profilesToCollect.contains(ProfileName.ERROR)) { + totalResponsesToWait++; + } + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + if (profilesToCollect.contains(ProfileName.TOTAL_ENTITIES)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS) + || profilesToCollect.contains(ProfileName.ACTIVE_ENTITIES) + || profilesToCollect.contains(ProfileName.INIT_PROGRESS) + || profilesToCollect.contains(ProfileName.STATE)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(taskProfile)) { + totalResponsesToWait++; + } + + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + totalResponsesToWait, + CommonMessages.FAIL_FETCH_ERR_MSG + configId, + false + ); + if (profilesToCollect.contains(ProfileName.ERROR)) { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, realTimeTaskTypes, task -> { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + if (task.isPresent()) { + long lastUpdateTimeMs = task.get().getLastUpdateTime().toEpochMilli(); + + // if state index hasn't been updated, we should not use the error field + // For example, before a detector is enabled, if the error message contains + // the phrase "stopped due to blah", we should not show this when the detector + // is enabled. + if (lastUpdateTimeMs > enabledTimeMs && task.get().getError() != null) { + profileBuilder.error(task.get().getError()); + } + delegateListener.onResponse(profileBuilder.build()); + } else { + // detector state for this detector does not exist + delegateListener.onResponse(profileBuilder.build()); + } + }, transportService, false, delegateListener); + } + + // total number of listeners we need to define. 
Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + if (profilesToCollect.contains(ProfileName.TOTAL_ENTITIES)) { + profileEntityStats(delegateListener, config); + } + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE) + || profilesToCollect.contains(ProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(ProfileName.MODELS) + || profilesToCollect.contains(ProfileName.ACTIVE_ENTITIES) + || profilesToCollect.contains(ProfileName.INIT_PROGRESS) + || profilesToCollect.contains(ProfileName.STATE)) { + profileModels(config, profilesToCollect, job, true, delegateListener); + } + if (profilesToCollect.contains(taskProfile)) { + getLatestHistoricalTaskProfile(configId, transportService, null, delegateListener); + } + + } catch (Exception e) { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + listener.onFailure(e); + } + } else { + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + logger.info(exception.getMessage()); + onGetDetectorForPrepare(configId, listener, profilesToCollect); + } else { + logger.error(CommonMessages.FAIL_TO_GET_PROFILE_MSG + configId); + listener.onFailure(exception); + } + })); + } + + private void profileEntityStats(MultiResponsesDelegateActionListener listener, Config config) { + List categoryField = config.getCategoryFields(); + if (!config.isHighCardinality() || categoryField.size() > maxCategoricalFields) { + listener.onResponse(createProfileBuilder().build()); + } else { + if (categoryField.size() == 1) { + // Run a cardinality aggregation to count the cardinality of single category fields + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + CardinalityAggregationBuilder aggBuilder = new CardinalityAggregationBuilder(CommonName.TOTAL_ENTITIES); + aggBuilder.field(categoryField.get(0)); + searchSourceBuilder.aggregation(aggBuilder); + + SearchRequest request = new SearchRequest(config.getIndices().toArray(new String[0]), searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + Map aggMap = searchResponse.getAggregations().asMap(); + InternalCardinality totalEntities = (InternalCardinality) aggMap.get(CommonName.TOTAL_ENTITIES); + long value = totalEntities.getValue(); + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + ConfigProfileType profile = profileBuilder.totalEntities(value).build(); + listener.onResponse(profile); + }, searchException -> { + logger.warn(CommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + config.getId()); + listener.onFailure(searchException); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + config.getId(), + client, + analysisType, + searchResponseListener + ); + } else { + // Run a composite query and count the number of buckets to decide cardinality of multiple category fields + AggregationBuilder bucketAggs = AggregationBuilders + .composite( + CommonName.TOTAL_ENTITIES, + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(maxTotalEntitiesToTrack); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0); + 
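Condensing the two entity-counting branches above into one hedged sketch: a single category field uses an approximate cardinality aggregation, while multiple fields use composite buckets whose count, capped by the composite page size maxTotalEntitiesToTrack, stands in for the total entity count. The helper name and the "total_entities" key are illustrative; the builder calls are the ones this file already imports.

    import java.util.List;
    import java.util.stream.Collectors;
    import org.opensearch.search.aggregations.AggregationBuilders;
    import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder;
    import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder;
    import org.opensearch.search.builder.SearchSourceBuilder;

    // Sketch: choose the entity-counting aggregation by category-field count.
    SearchSourceBuilder entityCountSource(List<String> categoryFields, int maxTotalEntitiesToTrack) {
        if (categoryFields.size() == 1) {
            // Approximate distinct count (HyperLogLog-based) of a single field.
            return new SearchSourceBuilder()
                .aggregation(new CardinalityAggregationBuilder("total_entities").field(categoryFields.get(0)));
        }
        // Composite buckets over every field; the bucket count is the entity count,
        // truncated at maxTotalEntitiesToTrack by the composite page size.
        return new SearchSourceBuilder()
            .aggregation(
                AggregationBuilders
                    .composite(
                        "total_entities",
                        categoryFields.stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList())
                    )
                    .size(maxTotalEntitiesToTrack)
            )
            .trackTotalHits(false)
            .size(0);
    }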
SearchRequest searchRequest = new SearchRequest() + .indices(config.getIndices().toArray(new String[0])) + .source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + Aggregations aggs = searchResponse.getAggregations(); + if (aggs == null) { + // This would indicate a bug or some OpenSearch core change that we are not aware of (we don't keep up-to-date + // with + // the large amount of changes there). For example, core may change the behavior to return results if there are any and + // return + // null otherwise, instead of an empty Aggregations as it currently does. + logger.warn("Unexpected null aggregation."); + listener.onResponse(profileBuilder.totalEntities(0L).build()); + return; + } + + Aggregation aggrResult = aggs.get(CommonName.TOTAL_ENTITIES); + if (aggrResult == null) { + listener.onFailure(new IllegalArgumentException("Fail to find valid aggregation result")); + return; + } + + CompositeAggregation compositeAgg = (CompositeAggregation) aggrResult; + ConfigProfileType profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build(); + listener.onResponse(profile); + }, searchException -> { + logger.warn(CommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + config.getId()); + listener.onFailure(searchException); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + config.getId(), + client, + analysisType, + searchResponseListener + ); + } + + } + } + + protected void onGetDetectorForPrepare(String configId, ActionListener listener, Set profiles) { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + if (profiles.contains(ProfileName.STATE)) { + profileBuilder.state(ConfigState.DISABLED); + } + if (profiles.contains(taskProfile)) { + getLatestHistoricalTaskProfile(configId, transportService, profileBuilder.build(), listener); + } else { + listener.onResponse(profileBuilder.build()); + } + } + + /** + * Profile model-related information. + * + * @param config Config accessor + * @param profiles profiles to collect + * @param job Job accessor + * @param modelInPriorityCache Whether the models are stored in priority cache. AD single stream models are stored in ModelManager. + * Other models are stored in priority cache.
+ * @param listener returns collected profiles + */ + protected void profileModels( + Config config, + Set profiles, + Job job, + boolean modelInPriorityCache, + MultiResponsesDelegateActionListener listener + ) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + ProfileRequest profileRequest = new ProfileRequest(config.getId(), profiles, modelInPriorityCache, dataNodes); + client.execute(profileAction, profileRequest, onModelResponse(config, profiles, job, modelInPriorityCache, listener));// get init + // progress + } + + private ActionListener onModelResponse( + Config config, + Set profilesToCollect, + Job job, + boolean modelInPriorityCache, + MultiResponsesDelegateActionListener listener + ) { + boolean isMultientityDetector = config.isHighCardinality(); + return ActionListener.wrap(profileResponse -> { + ConfigProfileType.Builder profile = createProfileBuilder(); + if (profilesToCollect.contains(ProfileName.COORDINATING_NODE)) { + profile.coordinatingNode(profileResponse.getCoordinatingNode()); + } + if (profilesToCollect.contains(ProfileName.SHINGLE_SIZE)) { + profile.shingleSize(profileResponse.getShingleSize()); + } + if (profilesToCollect.contains(ProfileName.TOTAL_SIZE_IN_BYTES)) { + profile.totalSizeInBytes(profileResponse.getTotalSizeInBytes()); + } + if (profilesToCollect.contains(ProfileName.MODELS)) { + profile.modelProfile(profileResponse.getModelProfile()); + profile.modelCount(profileResponse.getModelCount()); + } + if (isMultientityDetector && profilesToCollect.contains(ProfileName.ACTIVE_ENTITIES)) { + profile.activeEntities(profileResponse.getActiveEntities()); + } + + // only need to do it for models in priority cache. AD single stream analysis has a + // different workflow to determine state and init progress + if (modelInPriorityCache + && (profilesToCollect.contains(ProfileName.INIT_PROGRESS) || profilesToCollect.contains(ProfileName.STATE))) { + profileStateRelated(job, profilesToCollect, profileResponse, profile, config, listener); + } else { + listener.onResponse(profile.build()); + } + }, listener::onFailure); + } + + private void profileStateRelated( + Job job, + Set profilesToCollect, + ProfileResponse profileResponse, + ConfigProfileType.Builder profileBuilder, + Config config, + MultiResponsesDelegateActionListener listener + ) { + if (job.isEnabled()) { + if (profileResponse.getTotalUpdates() < requiredSamples) { + // need to double check for an HC analysis + // since what ProfileResponse returns is the highest priority entity currently in memory, but + // another entity might have already been initialized and sit somewhere else (in memory or on disk). 
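The double check described in the comment above boils down to a single bounded query: any real-time result indexed after the job's enabled time proves that at least one model finished initializing. A minimal sketch of such a probe; the literal field names stand in for the constants (DETECTOR_ID_FIELD, EXECUTION_END_TIME_FIELD, TASK_ID_FIELD) used elsewhere in this diff, and configId, enabledTimeMs, and resultIndexAlias are assumed locals.

    import org.opensearch.action.search.SearchRequest;
    import org.opensearch.index.query.BoolQueryBuilder;
    import org.opensearch.index.query.QueryBuilders;
    import org.opensearch.search.builder.SearchSourceBuilder;

    // Sketch: one matching hit is enough, so size(1) bounds the search cost.
    BoolQueryBuilder initProbe = new BoolQueryBuilder()
        .filter(QueryBuilders.termQuery("detector_id", configId)) // this config only
        .filter(QueryBuilders.rangeQuery("execution_end_time").gte(enabledTimeMs)) // after job enablement
        .mustNot(QueryBuilders.existsQuery("task_id")); // real-time results carry no task_id
    SearchRequest probeRequest = new SearchRequest(resultIndexAlias).source(new SearchSourceBuilder().query(initProbe).size(1));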
+ long enabledTime = job.getEnabledTime().toEpochMilli(); + long totalUpdates = profileResponse.getTotalUpdates(); + ProfileUtil + .confirmRealtimeInitStatus( + config, + enabledTime, + client, + analysisType, + onInittedEver(enabledTime, profileBuilder, profilesToCollect, config, totalUpdates, listener) + ); + } else { + createRunningStateAndInitProgress(profilesToCollect, profileBuilder); + listener.onResponse(profileBuilder.build()); + } + } else { + if (profilesToCollect.contains(ProfileName.STATE)) { + profileBuilder.state(ConfigState.DISABLED); + } + listener.onResponse(profileBuilder.build()); + } + } + + private ActionListener onInittedEver( + long lastUpdateTimeMs, + ConfigProfileType.Builder profileBuilder, + Set profilesToCollect, + Config config, + long totalUpdates, + MultiResponsesDelegateActionListener listener + ) { + return ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + if (hits.getTotalHits().value == 0L) { + processInitResponse(config, profilesToCollect, totalUpdates, false, profileBuilder, listener); + } else { + createRunningStateAndInitProgress(profilesToCollect, profileBuilder); + listener.onResponse(profileBuilder.build()); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + processInitResponse(config, profilesToCollect, totalUpdates, false, profileBuilder, listener); + } else { + logger + .error( + "Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}", + config.getId() + ); + listener.onFailure(exception); + } + }); + } + + protected void createRunningStateAndInitProgress( + Set profilesToCollect, + ConfigProfileType.Builder builder + ) { + if (profilesToCollect.contains(ProfileName.STATE)) { + builder.state(ConfigState.RUNNING).build(); + } + + if (profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { + InitProgressProfile initProgress = new InitProgressProfile("100%", 0, 0); + builder.initProgress(initProgress); + } + } + + protected void processInitResponse( + Config config, + Set profilesToCollect, + long totalUpdates, + boolean hideMinutesLeft, + ConfigProfileType.Builder builder, + MultiResponsesDelegateActionListener listener + ) { + if (profilesToCollect.contains(ProfileName.STATE)) { + builder.state(ConfigState.INIT); + } + + if (profilesToCollect.contains(ProfileName.INIT_PROGRESS)) { + if (hideMinutesLeft) { + InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0); + builder.initProgress(initProgress); + } else { + long intervalMins = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMinutes(); + InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins); + builder.initProgress(initProgress); + } + } + + listener.onResponse(builder.build()); + } + + /** + * Get latest historical config task profile. + * Will not reset task state in this method. 
+ * + * @param configId config id + * @param transportService transport service + * @param profile config profile + * @param listener action listener + */ + public void getLatestHistoricalTaskProfile( + String configId, + TransportService transportService, + ConfigProfileType profile, + ActionListener listener + ) { + taskManager.getAndExecuteOnLatestConfigTask(configId, null, null, batchConfigTaskTypes, task -> { + if (task.isPresent()) { + taskProfileRunner.getTaskProfile(task.get(), ActionListener.wrap(taskProfile -> { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + profileBuilder.taskProfile(taskProfile); + ConfigProfileType configProfile = profileBuilder.build(); + configProfile.merge(profile); + listener.onResponse(configProfile); + }, e -> { + logger.error("Failed to get task profile for task " + task.get().getTaskId(), e); + listener.onFailure(e); + })); + } else { + ConfigProfileType.Builder profileBuilder = createProfileBuilder(); + listener.onResponse(profileBuilder.build()); + } + }, transportService, false, listener); + } + + protected abstract ConfigProfileType.Builder createProfileBuilder(); + +} diff --git a/src/main/java/org/opensearch/timeseries/ProfileTask.java b/src/main/java/org/opensearch/timeseries/ProfileTask.java new file mode 100644 index 000000000..9b68a2db6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileTask.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.model.TimeSeriesTask; + +/** + * Break the cross dependency between TaskManager and ProfileRunner. Instead of + * depending on each other, they depend on the interface. + * + */ +public interface ProfileTask> { + void getTaskProfile(TaskClass configLevelTask, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/timeseries/ProfileUtil.java b/src/main/java/org/opensearch/timeseries/ProfileUtil.java new file mode 100644 index 000000000..b6de04ba7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileUtil.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Config; + +public class ProfileUtil { + /** + * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. + * Note this function is only meant to check for status of real time analysis. 
+     *
+     * @param detectorId detector id
+     * @param enabledTime the time when the AD job is enabled, in epoch milliseconds
+     * @param resultIndex custom result index to search; when null, the default result index alias is searched
+     * @return the search request
+     */
+    private static SearchRequest createADRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) {
+        BoolQueryBuilder filterQuery = new BoolQueryBuilder();
+        filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId));
+        filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime));
+        filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0));
+        // Historical analysis results are also stored in the result index and carry a non-null task_id.
+        // For realtime results, we should filter for task_id == null.
+        ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD);
+        filterQuery.mustNot(taskIdExistsFilter);
+
+        SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1);
+
+        SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS);
+        request.source(source);
+        if (resultIndex != null) {
+            request.indices(resultIndex);
+        }
+        return request;
+    }
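
Distilled out of the builder above: the realtime-only predicate is three filters plus a mustNot on task_id (historical and run-once results carry one, realtime results do not). A standalone sketch using the same QueryBuilders API the PR uses, with literal field names standing in for the constants:

import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;

// Sketch of the realtime-only predicate shared by both request builders.
// Field names are assumed values of the constants used in the PR.
final class RealtimeInittedQuerySketch {
    static SearchSourceBuilder realtimeInittedEver(String detectorId, long enabledTimeMs) {
        BoolQueryBuilder filter = new BoolQueryBuilder()
            .filter(QueryBuilders.termQuery("detector_id", detectorId))
            // only results produced after the job was (re)enabled
            .filter(QueryBuilders.rangeQuery("execution_end_time").gte(enabledTimeMs))
            // "initted" means at least one non-zero score
            .filter(QueryBuilders.rangeQuery("anomaly_score").gt(0))
            // historical/run-once results carry a task_id; realtime results do not
            .mustNot(QueryBuilders.existsQuery("task_id"));
        // one hit is enough to answer the yes/no question
        return new SearchSourceBuilder().query(filter).size(1);
    }
}

+
+    /**
+     * Create search request to check if we have at least 1 forecast after the forecast job enabled time.
+     * Note this function is only meant to check the status of real time analysis.
+     *
+     * @param forecasterId forecaster id
+     * @param enabledTime the time when the forecast job is enabled, in epoch milliseconds
+     * @param resultIndex custom result index to search; when null, the default result index is searched
+     * @return the search request
+     */
+    private static SearchRequest createForecastRealtimeInittedEverRequest(String forecasterId, long enabledTime, String resultIndex) {
+        BoolQueryBuilder filterQuery = new BoolQueryBuilder();
+        filterQuery.filter(QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, forecasterId));
+        filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime));
+        ExistsQueryBuilder forecastsExistFilter = QueryBuilders.existsQuery(ForecastResult.VALUE_FIELD);
+        filterQuery.must(forecastsExistFilter);
+        // Historical/run-once analysis results are also stored in the result index and carry a non-null task_id.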
+        // For realtime forecast results, we should filter for task_id == null.
+        ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD);
+        filterQuery.mustNot(taskIdExistsFilter);
+
+        SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1);
+
+        SearchRequest request = new SearchRequest(ForecastIndex.RESULT.getIndexName());
+        request.source(source);
+        if (resultIndex != null) {
+            request.indices(resultIndex);
+        }
+        return request;
+    }
+
+    public static void confirmRealtimeInitStatus(
+        Config config,
+        long enabledTime,
+        Client client,
+        AnalysisType analysisType,
+        ActionListener<SearchResponse> listener
+    ) {
+        SearchRequest searchLatestResult = null;
+        if (analysisType.isAD()) {
+            searchLatestResult = createADRealtimeInittedEverRequest(config.getId(), enabledTime, config.getCustomResultIndex());
+        } else if (analysisType.isForecast()) {
+            searchLatestResult = createForecastRealtimeInittedEverRequest(config.getId(), enabledTime, config.getCustomResultIndex());
+        } else {
+            throw new IllegalArgumentException("Analysis type is not supported, type: " + analysisType);
+        }
+
+        client.search(searchLatestResult, listener);
+    }
+}
diff --git a/src/main/java/org/opensearch/timeseries/TaskProfile.java b/src/main/java/org/opensearch/timeseries/TaskProfile.java
new file mode 100644
index 000000000..4abf41897
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/TaskProfile.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.timeseries;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import org.opensearch.core.common.io.stream.Writeable;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.timeseries.annotation.Generated;
+import org.opensearch.timeseries.model.TimeSeriesTask;
+
+public abstract class TaskProfile<TaskType extends TimeSeriesTask> implements ToXContentObject, Writeable {
+
+    public static final String SHINGLE_SIZE_FIELD = "shingle_size";
+    public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates";
+    public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes";
+    public static final String NODE_ID_FIELD = "node_id";
+    public static final String TASK_ID_FIELD = "task_id";
+    public static final String TASK_TYPE_FIELD = "task_type";
+    public static final String ENTITY_TASK_PROFILE_FIELD = "entity_task_profiles";
+
+    protected TaskType task;
+    protected Integer shingleSize;
+    protected Long rcfTotalUpdates;
+    protected Long modelSizeInBytes;
+    protected String nodeId;
+    protected String taskId;
+    protected String taskType;
+
+    public TaskProfile() {
+
+    }
+
+    public TaskProfile(TaskType task) {
+        this.task = task;
+    }
+
+    public TaskProfile(String taskId, int shingleSize, long rcfTotalUpdates, long modelSizeInBytes, String nodeId) {
+        this.taskId = taskId;
+        this.shingleSize = shingleSize;
+        this.rcfTotalUpdates = rcfTotalUpdates;
+        this.modelSizeInBytes = modelSizeInBytes;
+        this.nodeId = nodeId;
+    }
+
+    public TaskProfile(
+        TaskType adTask,
+        Integer shingleSize,
+        Long rcfTotalUpdates,
+        Long modelSizeInBytes,
+        String nodeId,
+        String taskId,
+        String adTaskType
+    ) {
+        this.task = adTask;
+        this.shingleSize = shingleSize;
+        this.rcfTotalUpdates = rcfTotalUpdates;
+        this.modelSizeInBytes = modelSizeInBytes;
+        this.nodeId = nodeId;
+        this.taskId = taskId;
+        this.taskType = adTaskType;
+    }
+
+    public TaskType getTask() {
+        return task;
+    }
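
Every field besides the task itself is deliberately a boxed type: the toXContent further down emits only the fields that are set, so a partial profile serializes as a sparse document. A plain-Java sketch of that guard pattern, with a Map standing in for XContentBuilder so it runs standalone:

import java.util.LinkedHashMap;
import java.util.Map;

// Null-guarded rendering: boxed fields are emitted only when set.
final class OptionalFieldRenderingSketch {

    static Map<String, Object> render(Integer shingleSize, Long rcfTotalUpdates, String nodeId) {
        Map<String, Object> out = new LinkedHashMap<>();
        if (shingleSize != null) {
            out.put("shingle_size", shingleSize);
        }
        if (rcfTotalUpdates != null) {
            out.put("rcf_total_updates", rcfTotalUpdates);
        }
        if (nodeId != null) {
            out.put("node_id", nodeId);
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(render(8, null, null)); // {shingle_size=8}; unset fields disappear
    }
}

+
+    public void setTask(TaskType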
adTask) { + this.task = adTask; + } + + public Integer getShingleSize() { + return shingleSize; + } + + public void setShingleSize(Integer shingleSize) { + this.shingleSize = shingleSize; + } + + public Long getRcfTotalUpdates() { + return rcfTotalUpdates; + } + + public void setRcfTotalUpdates(Long rcfTotalUpdates) { + this.rcfTotalUpdates = rcfTotalUpdates; + } + + public Long getModelSizeInBytes() { + return modelSizeInBytes; + } + + public void setModelSizeInBytes(Long modelSizeInBytes) { + this.modelSizeInBytes = modelSizeInBytes; + } + + public String getNodeId() { + return nodeId; + } + + public void setNodeId(String nodeId) { + this.nodeId = nodeId; + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public String getTaskType() { + return taskType; + } + + public void setTaskType(String taskType) { + this.taskType = taskType; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + TaskProfile that = (TaskProfile) o; + return Objects.equals(task, that.task) + && Objects.equals(shingleSize, that.shingleSize) + && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) + && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) + && Objects.equals(nodeId, that.nodeId) + && Objects.equals(taskId, that.taskId) + && Objects.equals(taskType, that.taskType); + } + + @Generated + @Override + public int hashCode() { + return Objects.hash(task, shingleSize, rcfTotalUpdates, modelSizeInBytes, nodeId, taskId, taskType); + } + + protected void toXContent(XContentBuilder xContentBuilder) throws IOException { + if (task != null) { + xContentBuilder.field(getTaskFieldName(), task); + } + if (shingleSize != null) { + xContentBuilder.field(SHINGLE_SIZE_FIELD, shingleSize); + } + if (rcfTotalUpdates != null) { + xContentBuilder.field(RCF_TOTAL_UPDATES_FIELD, rcfTotalUpdates); + } + if (modelSizeInBytes != null) { + xContentBuilder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + if (nodeId != null) { + xContentBuilder.field(NODE_ID_FIELD, nodeId); + } + if (taskId != null) { + xContentBuilder.field(TASK_ID_FIELD, taskId); + } + if (taskType != null) { + xContentBuilder.field(TASK_TYPE_FIELD, taskType); + } + } + + protected abstract String getTaskFieldName(); +} diff --git a/src/main/java/org/opensearch/timeseries/TaskProfileRunner.java b/src/main/java/org/opensearch/timeseries/TaskProfileRunner.java new file mode 100644 index 000000000..6f29fd244 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/TaskProfileRunner.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.model.TimeSeriesTask; + +/** + * break the cross dependency between ProfileRunner and TaskManager. Instead, both of them depend on TaskProfileRunner. 
+ */ +public interface TaskProfileRunner> { + void getTaskProfile(TaskClass configLevelTask, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java index 7dadac650..9c819886a 100644 --- a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java @@ -12,6 +12,8 @@ package org.opensearch.timeseries; import static java.util.Collections.unmodifiableList; +import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; import java.security.AccessController; import java.security.PrivilegedAction; @@ -33,33 +35,29 @@ import org.apache.logging.log4j.Logger; import org.opensearch.SpecialPermission; import org.opensearch.action.ActionRequest; -import org.opensearch.ad.AnomalyDetectorJobRunner; +import org.opensearch.ad.ADJobProcessor; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.AnomalyDetectorRunner; import org.opensearch.ad.ExecuteADResultResponseRecorder; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.caching.PriorityCache; -import org.opensearch.ad.cluster.ADClusterEventListener; -import org.opensearch.ad.cluster.ADDataMigrator; -import org.opensearch.ad.cluster.ClusterManagerEventListener; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; -import org.opensearch.ad.ratelimit.CheckPointMaintainRequestAdapter; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; import org.opensearch.ad.rest.RestAnomalyDetectorJobAction; import org.opensearch.ad.rest.RestDeleteAnomalyDetectorAction; import org.opensearch.ad.rest.RestDeleteAnomalyResultsAction; @@ -74,17 +72,14 @@ import org.opensearch.ad.rest.RestSearchTopAnomalyResultAction; import org.opensearch.ad.rest.RestStatsAnomalyDetectorAction; import 
org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.LegacyOpenDistroAnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeCountSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeCountSupplier; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier; import org.opensearch.ad.task.ADBatchTaskRunner; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; @@ -94,6 +89,10 @@ import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionTransportAction; import org.opensearch.ad.transport.ADCancelTaskAction; import org.opensearch.ad.transport.ADCancelTaskTransportAction; +import org.opensearch.ad.transport.ADEntityProfileAction; +import org.opensearch.ad.transport.ADEntityProfileTransportAction; +import org.opensearch.ad.transport.ADProfileAction; +import org.opensearch.ad.transport.ADProfileTransportAction; import org.opensearch.ad.transport.ADResultBulkAction; import org.opensearch.ad.transport.ADResultBulkTransportAction; import org.opensearch.ad.transport.ADStatsNodesAction; @@ -105,17 +104,14 @@ import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultTransportAction; import org.opensearch.ad.transport.CronAction; -import org.opensearch.ad.transport.CronTransportAction; +import org.opensearch.ad.transport.DeleteADModelAction; +import org.opensearch.ad.transport.DeleteADModelTransportAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.ad.transport.DeleteAnomalyResultsAction; import org.opensearch.ad.transport.DeleteAnomalyResultsTransportAction; -import org.opensearch.ad.transport.DeleteModelAction; -import org.opensearch.ad.transport.DeleteModelTransportAction; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileTransportAction; -import org.opensearch.ad.transport.EntityResultAction; -import org.opensearch.ad.transport.EntityResultTransportAction; +import org.opensearch.ad.transport.EntityADResultAction; +import org.opensearch.ad.transport.EntityADResultTransportAction; import org.opensearch.ad.transport.ForwardADTaskAction; import org.opensearch.ad.transport.ForwardADTaskTransportAction; import org.opensearch.ad.transport.GetAnomalyDetectorAction; @@ -124,8 +120,6 @@ import org.opensearch.ad.transport.IndexAnomalyDetectorTransportAction; import org.opensearch.ad.transport.PreviewAnomalyDetectorAction; import org.opensearch.ad.transport.PreviewAnomalyDetectorTransportAction; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileTransportAction; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingTransportAction; import org.opensearch.ad.transport.RCFResultAction; @@ -148,11 +142,8 @@ import 
org.opensearch.ad.transport.ThresholdResultTransportAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.ad.transport.handler.ADSearchHandler; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; -import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; @@ -171,8 +162,91 @@ import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.ForecastJobProcessor; +import org.opensearch.forecast.ForecastTaskProfileRunner; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdEntityWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastSaveResultStrategy; +import org.opensearch.forecast.rest.RestDeleteForecasterAction; +import org.opensearch.forecast.rest.RestForecasterJobAction; +import org.opensearch.forecast.rest.RestForecasterSuggestAction; +import org.opensearch.forecast.rest.RestGetForecasterAction; +import org.opensearch.forecast.rest.RestIndexForecasterAction; +import org.opensearch.forecast.rest.RestRunOnceForecasterAction; +import org.opensearch.forecast.rest.RestSearchForecastTasksAction; +import org.opensearch.forecast.rest.RestSearchForecasterAction; +import org.opensearch.forecast.rest.RestSearchForecasterInfoAction; +import org.opensearch.forecast.rest.RestSearchTopForecastResultAction; +import org.opensearch.forecast.rest.RestStatsForecasterAction; +import org.opensearch.forecast.rest.RestValidateForecasterAction; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastNumericSetting; import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.stats.ForecastModelsOnNodeSupplier; +import org.opensearch.forecast.stats.ForecastStats; +import org.opensearch.forecast.stats.suppliers.ForecastModelsOnNodeCountSupplier; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.DeleteForecastModelAction; +import 
org.opensearch.forecast.transport.DeleteForecastModelTransportAction; +import org.opensearch.forecast.transport.DeleteForecasterAction; +import org.opensearch.forecast.transport.DeleteForecasterTransportAction; +import org.opensearch.forecast.transport.EntityForecastResultAction; +import org.opensearch.forecast.transport.EntityForecastResultTransportAction; +import org.opensearch.forecast.transport.ForecastEntityProfileAction; +import org.opensearch.forecast.transport.ForecastEntityProfileTransportAction; +import org.opensearch.forecast.transport.ForecastProfileAction; +import org.opensearch.forecast.transport.ForecastProfileTransportAction; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultBulkAction; +import org.opensearch.forecast.transport.ForecastResultBulkTransportAction; +import org.opensearch.forecast.transport.ForecastResultTransportAction; +import org.opensearch.forecast.transport.ForecastRunOnceAction; +import org.opensearch.forecast.transport.ForecastRunOnceProfileAction; +import org.opensearch.forecast.transport.ForecastRunOnceProfileTransportAction; +import org.opensearch.forecast.transport.ForecastRunOnceTransportAction; +import org.opensearch.forecast.transport.ForecastSingleStreamResultAction; +import org.opensearch.forecast.transport.ForecastSingleStreamResultTransportAction; +import org.opensearch.forecast.transport.ForecastStatsNodesAction; +import org.opensearch.forecast.transport.ForecastStatsNodesTransportAction; +import org.opensearch.forecast.transport.ForecasterJobAction; +import org.opensearch.forecast.transport.ForecasterJobTransportAction; +import org.opensearch.forecast.transport.GetForecasterAction; +import org.opensearch.forecast.transport.GetForecasterTransportAction; +import org.opensearch.forecast.transport.IndexForecasterAction; +import org.opensearch.forecast.transport.IndexForecasterTransportAction; +import org.opensearch.forecast.transport.SearchForecastTasksAction; +import org.opensearch.forecast.transport.SearchForecastTasksTransportAction; +import org.opensearch.forecast.transport.SearchForecasterAction; +import org.opensearch.forecast.transport.SearchForecasterInfoAction; +import org.opensearch.forecast.transport.SearchForecasterInfoTransportAction; +import org.opensearch.forecast.transport.SearchForecasterTransportAction; +import org.opensearch.forecast.transport.SearchTopForecastResultAction; +import org.opensearch.forecast.transport.SearchTopForecastResultTransportAction; +import org.opensearch.forecast.transport.StatsForecasterAction; +import org.opensearch.forecast.transport.StatsForecasterTransportAction; +import org.opensearch.forecast.transport.StopForecasterAction; +import org.opensearch.forecast.transport.StopForecasterTransportAction; +import org.opensearch.forecast.transport.SuggestForecasterParamAction; +import org.opensearch.forecast.transport.SuggestForecasterParamTransportAction; +import org.opensearch.forecast.transport.ValidateForecasterAction; +import org.opensearch.forecast.transport.ValidateForecasterTransportAction; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.forecast.transport.handler.ForecastSearchHandler; import org.opensearch.jobscheduler.spi.JobSchedulerExtension; import org.opensearch.jobscheduler.spi.ScheduledJobParser; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; @@ -189,19 +263,38 @@ import org.opensearch.threadpool.ScalingExecutorBuilder; import 
org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.ADDataMigrator; +import org.opensearch.timeseries.cluster.ClusterEventListener; +import org.opensearch.timeseries.cluster.ClusterManagerEventListener; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ThrowingSupplierWrapper; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.settings.TimeSeriesEnabledSetting; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.CronTransportAction; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.IndexUtils; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.watcher.ResourceWatcherService; +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper; +import com.amazon.randomcutforest.parkservices.state.RCFCasterState; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; @@ -216,7 +309,7 @@ import io.protostuff.runtime.RuntimeSchema; /** - * Entry point of AD plugin. + * Entry point of time series analytics plugin. 
*/ public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { @@ -236,27 +329,32 @@ public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, S public static final String FORECAST_FORECASTERS_URI = FORECAST_BASE_URI + "/forecasters"; public static final String FORECAST_THREAD_POOL_PREFIX = "opensearch.forecast."; public static final String FORECAST_THREAD_POOL_NAME = "forecast-threadpool"; - public static final String FORECAST_BATCH_TASK_THREAD_POOL_NAME = "forecast-batch-task-threadpool"; public static final String TIME_SERIES_JOB_TYPE = "opensearch_time_series_analytics"; private static Gson gson; private ADIndexManagement anomalyDetectionIndices; + private ForecastIndexManagement forecastIndices; private AnomalyDetectorRunner anomalyDetectorRunner; private Client client; private ClusterService clusterService; private ThreadPool threadPool; private ADStats adStats; + private ForecastStats forecastStats; private ClientUtil clientUtil; private SecurityClientUtil securityClientUtil; private DiscoveryNodeFilterer nodeFilter; private IndexUtils indexUtils; private ADTaskManager adTaskManager; + private ForecastTaskManager forecastTaskManager; private ADBatchTaskRunner adBatchTaskRunner; // package private for testing GenericObjectPool serializeRCFBufferPool; private NodeStateManager stateManager; private ExecuteADResultResponseRecorder adResultResponseRecorder; + private ExecuteForecastResultResponseRecorder forecastResultResponseRecorder; + private ADIndexJobActionHandler adIndexJobActionHandler; + private ForecastIndexJobActionHandler forecastIndexJobActionHandler; static { SpecialPermission.check(); @@ -277,14 +375,16 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); - jobRunner.setClient(client); - jobRunner.setThreadPool(threadPool); - jobRunner.setSettings(settings); - jobRunner.setAnomalyDetectionIndices(anomalyDetectionIndices); - jobRunner.setAdTaskManager(adTaskManager); - jobRunner.setNodeStateManager(stateManager); - jobRunner.setExecuteADResultResponseRecorder(adResultResponseRecorder); + // AD + ADJobProcessor adJobRunner = ADJobProcessor.getInstance(); + adJobRunner.setClient(client); + adJobRunner.setThreadPool(threadPool); + adJobRunner.registerSettings(settings); + adJobRunner.setIndexManagement(anomalyDetectionIndices); + adJobRunner.setTaskManager(adTaskManager); + adJobRunner.setNodeStateManager(stateManager); + adJobRunner.setExecuteResultResponseRecorder(adResultResponseRecorder); + adJobRunner.setIndexJobActionHandler(adIndexJobActionHandler); RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(); RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction(settings, clusterService); @@ -301,8 +401,33 @@ public List getRestHandlers( RestSearchTopAnomalyResultAction searchTopAnomalyResultAction = new RestSearchTopAnomalyResultAction(); RestValidateAnomalyDetectorAction validateAnomalyDetectorAction = new RestValidateAnomalyDetectorAction(settings, clusterService); + // Forecast + RestIndexForecasterAction restIndexForecasterAction = new RestIndexForecasterAction(settings, clusterService); + RestForecasterJobAction restForecasterJobAction = new RestForecasterJobAction(); + RestGetForecasterAction restGetForecasterAction = new RestGetForecasterAction(); + 
RestDeleteForecasterAction deleteForecasterAction = new RestDeleteForecasterAction(); + RestSearchForecasterAction searchForecasterAction = new RestSearchForecasterAction(); + RestSearchForecasterInfoAction searchForecasterInfoAction = new RestSearchForecasterInfoAction(); + RestSearchTopForecastResultAction searchTopForecastResultAction = new RestSearchTopForecastResultAction(); + RestSearchForecastTasksAction searchForecastTasksAction = new RestSearchForecastTasksAction(); + RestStatsForecasterAction statsForecasterAction = new RestStatsForecasterAction(forecastStats, this.nodeFilter); + RestRunOnceForecasterAction runOnceForecasterAction = new RestRunOnceForecasterAction(); + RestValidateForecasterAction validateForecasterAction = new RestValidateForecasterAction(settings, clusterService); + RestForecasterSuggestAction suggestForecasterParamAction = new RestForecasterSuggestAction(settings, clusterService); + + ForecastJobProcessor forecastJobRunner = ForecastJobProcessor.getInstance(); + forecastJobRunner.setClient(client); + forecastJobRunner.setThreadPool(threadPool); + forecastJobRunner.registerSettings(settings); + forecastJobRunner.setIndexManagement(forecastIndices); + forecastJobRunner.setTaskManager(forecastTaskManager); + forecastJobRunner.setNodeStateManager(stateManager); + forecastJobRunner.setExecuteResultResponseRecorder(forecastResultResponseRecorder); + forecastJobRunner.setIndexJobActionHandler(forecastIndexJobActionHandler); + return ImmutableList .of( + // AD restGetAnomalyDetectorAction, restIndexAnomalyDetectorAction, searchAnomalyDetectorAction, @@ -316,7 +441,20 @@ public List getRestHandlers( previewAnomalyDetectorAction, deleteAnomalyResultsAction, searchTopAnomalyResultAction, - validateAnomalyDetectorAction + validateAnomalyDetectorAction, + // Forecast + restIndexForecasterAction, + restForecasterJobAction, + restGetForecasterAction, + deleteForecasterAction, + searchForecasterAction, + searchForecasterInfoAction, + searchTopForecastResultAction, + searchForecastTasksAction, + statsForecasterAction, + runOnceForecasterAction, + validateForecasterAction, + suggestForecasterParamAction ); } @@ -339,30 +477,51 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - ADEnabledSetting.getInstance().init(clusterService); - ADNumericSetting.getInstance().init(clusterService); + // ===================== + // Common components + // ===================== this.client = client; this.threadPool = threadPool; Settings settings = environment.settings(); this.clientUtil = new ClientUtil(client); - this.indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameExpressionResolver); + this.indexUtils = new IndexUtils(clusterService, indexNameExpressionResolver); this.nodeFilter = new DiscoveryNodeFilterer(clusterService); - // convert from checked IOException to unchecked RuntimeException - this.anomalyDetectionIndices = ThrowingSupplierWrapper - .throwingSupplierWrapper( - () -> new ADIndexManagement( - client, - clusterService, - threadPool, - settings, - nodeFilter, - TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES - ) - ) - .get(); this.clusterService = clusterService; - Imputer imputer = new LinearUniformImputer(true); + + JvmService jvmService = new JvmService(environment.settings()); + RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); + rcfMapper.setSaveExecutorContextEnabled(true); + rcfMapper.setSaveTreeStateEnabled(true); + 
rcfMapper.setPartialTreeStateEnabled(true); + V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); + + CircuitBreakerService circuitBreakerService = new CircuitBreakerService(jvmService).init(); + + long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + + serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { + @Override + public GenericObjectPool run() { + return new GenericObjectPool<>(new BasePooledObjectFactory() { + @Override + public LinkedBuffer create() throws Exception { + return LinkedBuffer.allocate(TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES); + } + + @Override + public PooledObject wrap(LinkedBuffer obj) { + return new DefaultPooledObject<>(obj); + } + }); + } + }); + serializeRCFBufferPool.setMaxTotal(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxIdle(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMinIdle(0); + serializeRCFBufferPool.setBlockWhenExhausted(false); + serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); + stateManager = new NodeStateManager( client, xContentRegistry, @@ -375,6 +534,7 @@ public Collection createComponents( TimeSeriesSettings.BACKOFF_MINUTES ); securityClientUtil = new SecurityClientUtil(stateManager, settings); + SearchFeatureDao searchFeatureDao = new SearchFeatureDao( client, xContentRegistry, @@ -385,27 +545,12 @@ public Collection createComponents( TimeSeriesSettings.NUM_SAMPLES_PER_TREE ); - JvmService jvmService = new JvmService(environment.settings()); - RandomCutForestMapper mapper = new RandomCutForestMapper(); - mapper.setSaveExecutorContextEnabled(true); - mapper.setSaveTreeStateEnabled(true); - mapper.setPartialTreeStateEnabled(true); - V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); - - double modelMaxSizePercent = AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings); - - CircuitBreakerService adCircuitBreakerService = new CircuitBreakerService(jvmService).init(); - - MemoryTracker memoryTracker = new MemoryTracker(jvmService, modelMaxSizePercent, clusterService, adCircuitBreakerService); - FeatureManager featureManager = new FeatureManager( searchFeatureDao, imputer, getClock(), - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -415,36 +560,36 @@ public Collection createComponents( AD_THREAD_POOL_NAME ); - long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + Random random = new Random(42); - serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { - @Override - public GenericObjectPool run() { - return new GenericObjectPool<>(new BasePooledObjectFactory() { - @Override - public LinkedBuffer create() throws Exception { - return LinkedBuffer.allocate(TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES); - } + // ===================== + // AD components + // ===================== + ADEnabledSetting.getInstance().init(clusterService); + ADNumericSetting.getInstance().init(clusterService); + // convert from checked IOException to unchecked RuntimeException + this.anomalyDetectionIndices = 
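
An aside on the buffer pool configured above: serializeRCFBufferPool caps how many protostuff LinkedBuffers can be live at once, and setBlockWhenExhausted(false) makes borrowObject fail fast rather than wait when the cap is hit. A minimal, self-contained sketch of working against that commons-pool2 contract, degrading to a throwaway buffer when the pool is drained; the serialization step itself is elided, so this illustrates the pool semantics, not the PR's checkpoint code:

import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericObjectPool;

import io.protostuff.LinkedBuffer;

final class BufferPoolSketch {
    private static final int BUFFER_BYTES = 512; // stand-in for SERIALIZATION_BUFFER_BYTES

    static byte[] serializeWithPool(GenericObjectPool<LinkedBuffer> pool) throws Exception {
        LinkedBuffer buffer;
        boolean borrowed = true;
        try {
            buffer = pool.borrowObject(); // fails fast: blockWhenExhausted is false
        } catch (Exception poolDrained) {
            buffer = LinkedBuffer.allocate(BUFFER_BYTES); // degrade to a one-off buffer
            borrowed = false;
        }
        try {
            return new byte[0]; // a real caller would run a protostuff schema against `buffer`
        } finally {
            buffer.clear(); // a LinkedBuffer must be cleared before it can be reused
            if (borrowed) {
                pool.returnObject(buffer);
            }
        }
    }

    public static void main(String[] args) throws Exception {
        GenericObjectPool<LinkedBuffer> pool = new GenericObjectPool<>(new BasePooledObjectFactory<LinkedBuffer>() {
            @Override
            public LinkedBuffer create() {
                return LinkedBuffer.allocate(BUFFER_BYTES);
            }

            @Override
            public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
                return new DefaultPooledObject<>(obj);
            }
        });
        pool.setMaxTotal(2);
        pool.setBlockWhenExhausted(false); // mirrors the plugin's fail-fast configuration
        serializeWithPool(pool);
    }
}
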
ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ) + ) + .get(); - @Override - public PooledObject wrap(LinkedBuffer obj) { - return new DefaultPooledObject<>(obj); - } - }); - } - }); - serializeRCFBufferPool.setMaxTotal(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMaxIdle(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMinIdle(0); - serializeRCFBufferPool.setBlockWhenExhausted(false); - serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); + double adModelMaxSizePercent = AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings); + + MemoryTracker adMemoryTracker = new MemoryTracker(jvmService, adModelMaxSizePercent, clusterService, circuitBreakerService); - CheckpointDao checkpoint = new CheckpointDao( + ADCheckpointDao adCheckpoint = new ADCheckpointDao( client, clientUtil, - ADCommonName.CHECKPOINT_INDEX_NAME, gson, - mapper, + rcfMapper, converter, new ThresholdedRandomCutForestMapper(), AccessController @@ -457,30 +602,30 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.MAX_CHECKPOINT_BYTES, serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - 1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE + 1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + getClock() ); - Random random = new Random(42); - - CacheProvider cacheProvider = new CacheProvider(); + ADCacheProvider adCacheProvider = new ADCacheProvider(); - CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( - cacheProvider, - checkpoint, - ADCommonName.CHECKPOINT_INDEX_NAME, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, - getClock(), - clusterService, - settings - ); + CheckPointMaintainRequestAdapter adAdapter = + new CheckPointMaintainRequestAdapter<>( + adCheckpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings, + adCacheProvider + ); - CheckpointWriteWorker checkpointWriteQueue = new CheckpointWriteWorker( + ADCheckpointWriteWorker adCheckpointWriteQueue = new ADCheckpointWriteWorker( heapSizeBytes, TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -489,20 +634,20 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - checkpoint, + adCheckpoint, ADCommonName.CHECKPOINT_INDEX_NAME, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, TimeSeriesSettings.HOURLY_MAINTENANCE ); - CheckpointMaintainWorker checkpointMaintainQueue = new CheckpointMaintainWorker( + ADCheckpointMaintainWorker adCheckpointMaintainQueue = new ADCheckpointMaintainWorker( heapSizeBytes, TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -510,33 +655,35 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, 
TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointWriteQueue, + adCheckpointWriteQueue, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - adapter + adAdapter::convert ); - EntityCache cache = new PriorityCache( - checkpoint, + ADPriorityCache adPriorityCache = new ADPriorityCache( + adCheckpoint, AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE.get(settings), AnomalyDetectorSettings.AD_CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, - memoryTracker, + adMemoryTracker, TimeSeriesSettings.NUM_TREES, getClock(), clusterService, TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, settings, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + adCheckpointWriteQueue, + adCheckpointMaintainQueue ); - cacheProvider.set(cache); + // cache provider allows us to break circular dependency among PriorityCache, CacheBuffer, + // CheckPointMaintainRequestAdapter, and CheckpointMaintainWorker + adCacheProvider.set(adPriorityCache); - EntityColdStarter entityColdStarter = new EntityColdStarter( + ADEntityColdStart adEntityColdStarter = new ADEntityColdStart( getClock(), threadPool, stateManager, @@ -546,23 +693,71 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.NUM_MIN_SAMPLES, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, TimeSeriesSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + adCheckpointWriteQueue, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()) + ); + + ADModelManager adModelManager = new ADModelManager( + adCheckpoint, + getClock(), + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + TimeSeriesSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + adEntityColdStarter, + featureManager, + adMemoryTracker, + settings, + clusterService + ); + + ADIndexMemoryPressureAwareResultHandler adIndexMemoryPressureAwareResultHandler = new ADIndexMemoryPressureAwareResultHandler( + client, + anomalyDetectionIndices + ); + + ADResultWriteWorker adResultWriteQueue = new ADResultWriteWorker( + heapSizeBytes, + TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adIndexMemoryPressureAwareResultHandler, + xContentRegistry, + stateManager, + TimeSeriesSettings.HOURLY_MAINTENANCE ); - EntityColdStartWorker coldstartQueue = new EntityColdStartWorker( + ADSaveResultStrategy adSaveResultStrategy = new ADSaveResultStrategy( + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + adResultWriteQueue + ); + + ADColdStartWorker adColdstartQueue = new ADColdStartWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, 
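
The circular-dependency comment above is worth unpacking: the maintain worker needs the cache, and the cache needs the write and maintain workers, so the workers are handed an empty ADCacheProvider first and the cache is dropped in afterwards via adCacheProvider.set(adPriorityCache). A generic sketch of that late-binding holder with hypothetical names (the real providers are typed per analysis and carry more API):

import java.util.Objects;
import java.util.function.Supplier;

// Late-binding holder: consumers capture it eagerly but dereference it lazily,
// so two mutually dependent components can be constructed in either order.
final class ProviderSketch<T> implements Supplier<T> {
    private T value; // set exactly once, after all consumers are constructed

    @Override
    public T get() {
        return Objects.requireNonNull(value, "provider used before wiring finished");
    }

    void set(T value) {
        this.value = value;
    }
}

final class MaintainWorkerSketch {
    private final Supplier<String> cache; // stands in for a provider of the priority cache

    MaintainWorkerSketch(Supplier<String> cache) {
        this.cache = cache;
    }

    void maintain() {
        System.out.println("maintaining " + cache.get()); // dereferenced only at run time
    }
}

final class WiringDemo {
    public static void main(String[] args) {
        ProviderSketch<String> provider = new ProviderSketch<>();
        MaintainWorkerSketch worker = new MaintainWorkerSketch(provider); // cache not built yet
        provider.set("priority-cache"); // the cycle is closed here, not at construction time
        worker.maintain();
    }
}
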
AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -571,47 +766,324 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - entityColdStarter, + adEntityColdStarter, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - cacheProvider + adPriorityCache, + adModelManager, + adSaveResultStrategy ); - ModelManager modelManager = new ModelManager( - checkpoint, + Map> adStatsMap = ImmutableMap + .>builder() + // ad stats + .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put( + StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) + ) + .put( + StatNames.AD_MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) + ) + .put( + StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) + ) + .put(StatNames.DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.HC_DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put( + StatNames.MODEL_INFORMATION.getName(), + new TimeSeriesStat<>(false, new ADModelsOnNodeSupplier(adModelManager, adCacheProvider, settings, clusterService)) + ) + .put( + StatNames.CONFIG_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) + ) + .put( + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) + .put( + StatNames.MODEL_COUNT.getName(), + new TimeSeriesStat<>(false, new ADModelsOnNodeCountSupplier(adModelManager, adCacheProvider)) + ) + .build(); + + adStats = new ADStats(adStatsMap); + + ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - TimeSeriesSettings.NUM_TREES, - 
TimeSeriesSettings.NUM_SAMPLES_PER_TREE, - TimeSeriesSettings.TIME_DECAY, - TimeSeriesSettings.NUM_MIN_SAMPLES, - TimeSeriesSettings.THRESHOLD_MIN_PVALUE, - AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adModelManager, + adCheckpoint, + adColdstartQueue, + stateManager, + anomalyDetectionIndices, + adCacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, - entityColdStarter, - featureManager, - memoryTracker, + adCheckpointWriteQueue, + adStats, + adSaveResultStrategy + ); + + ADColdEntityWorker adColdEntityQueue = new ADColdEntityWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, settings, - clusterService + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + adCheckpointReadQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager ); - MultiEntityResultHandler multiEntityResultHandler = new MultiEntityResultHandler( + ADDataMigrator adDataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); + + anomalyDetectorRunner = new AnomalyDetectorRunner(adModelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + + ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, adMemoryTracker); + + ResultBulkIndexingHandler anomalyResultBulkIndexHandler = + new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); + + ResultBulkIndexingHandler anomalyResultHandler = new ResultBulkIndexingHandler<>( client, settings, threadPool, + ANOMALY_RESULT_INDEX_ALIAS, anomalyDetectionIndices, this.clientUtil, this.indexUtils, - clusterService + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); - ResultWriteWorker resultWriteQueue = new ResultWriteWorker( + adResultResponseRecorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + stateManager, + adTaskCacheManager, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + + adIndexJobActionHandler = new ADIndexJobActionHandler( + client, + anomalyDetectionIndices, + xContentRegistry, + adTaskManager, + adResultResponseRecorder, + stateManager, + settings + ); + + // ===================== + // forecast components + // ===================== + ForecastEnabledSetting.getInstance().init(clusterService); + ForecastNumericSetting.getInstance().init(clusterService); + + forecastIndices = ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ForecastIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + ForecastSettings.FORECAST_MAX_UPDATE_RETRY_TIMES + ) + ) + .get(); + + double forecastModelMaxSizePercent = 
ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE.get(settings); + + MemoryTracker forecastMemoryTracker = new MemoryTracker( + jvmService, + forecastModelMaxSizePercent, + clusterService, + circuitBreakerService + ); + + ForecastCheckpointDao forecastCheckpoint = new ForecastCheckpointDao( + client, + clientUtil, + gson, + TimeSeriesSettings.MAX_CHECKPOINT_BYTES, + serializeRCFBufferPool, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, + forecastIndices, + new RCFCasterMapper(), + AccessController.doPrivileged((PrivilegedAction>) () -> RuntimeSchema.getSchema(RCFCasterState.class)), + getClock() + ); + + ForecastCacheProvider forecastCacheProvider = new ForecastCacheProvider(); + + CheckPointMaintainRequestAdapter forecastAdapter = + new CheckPointMaintainRequestAdapter( + forecastCheckpoint, + ForecastIndex.CHECKPOINT.getIndexName(), + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings, + forecastCacheProvider + ); + + ForecastCheckpointWriteWorker forecastCheckpointWriteQueue = new ForecastCheckpointWriteWorker( + heapSizeBytes, + TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastCheckpoint, + ForecastIndex.CHECKPOINT.getIndexName(), + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + TimeSeriesSettings.HOURLY_MAINTENANCE + ); + + ForecastCheckpointMaintainWorker forecastCheckpointMaintainQueue = new ForecastCheckpointMaintainWorker( + heapSizeBytes, + TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + forecastCheckpointWriteQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + forecastAdapter::convert + ); + + ForecastPriorityCache forecastPriorityCache = new ForecastPriorityCache( + forecastCheckpoint, + ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE.get(settings), + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + adMemoryTracker, + TimeSeriesSettings.NUM_TREES, + getClock(), + clusterService, + TimeSeriesSettings.HOURLY_MAINTENANCE, + threadPool, + FORECAST_THREAD_POOL_NAME, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + settings, + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + forecastCheckpointWriteQueue, + forecastCheckpointMaintainQueue + ); + + // cache provider allows us to break circular dependency among PriorityCache, CacheBuffer, + // CheckPointMaintainRequestAdapter, and CheckpointMaintainWorker + forecastCacheProvider.set(forecastPriorityCache); + + ForecastColdStart forecastColdStarter = new ForecastColdStart( + getClock(), + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + 
TimeSeriesSettings.HOURLY_MAINTENANCE, + forecastCheckpointWriteQueue, + (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()), + -1, // no hard coded random seed + -1, // interpolation is disabled so we don't need to specify the number of sampled points + TimeSeriesSettings.MAX_COLD_START_ROUNDS + ); + + ForecastModelManager forecastModelManager = new ForecastModelManager( + forecastCheckpoint, + getClock(), + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + forecastColdStarter, + forecastMemoryTracker, + featureManager + ); + + ForecastIndexMemoryPressureAwareResultHandler forecastIndexMemoryPressureAwareResultHandler = + new ForecastIndexMemoryPressureAwareResultHandler(client, forecastIndices); + + ForecastResultWriteWorker forecastResultWriteQueue = new ForecastResultWriteWorker( heapSizeBytes, TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, - AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -620,62 +1092,87 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - multiEntityResultHandler, + forecastIndexMemoryPressureAwareResultHandler, xContentRegistry, stateManager, TimeSeriesSettings.HOURLY_MAINTENANCE ); - Map> stats = ImmutableMap - .>builder() - .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + ForecastSaveResultStrategy forecastSaveResultStrategy = new ForecastSaveResultStrategy( + forecastIndices.getSchemaVersion(ForecastIndex.RESULT), + forecastResultWriteQueue + ); + + ForecastColdStartWorker forecastColdstartQueue = new ForecastColdStartWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastColdStarter, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + forecastPriorityCache, + forecastModelManager, + forecastSaveResultStrategy + ); + + Map> forecastStatsMap = ImmutableMap + .>builder() + // forecast stats + .put(StatNames.FORECAST_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) .put( - StatNames.MODEL_INFORMATION.getName(), - new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, 
clusterService)) + StatNames.FORECAST_RESULTS_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.RESULT.getIndexName())) ) .put( - StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) + StatNames.FORECAST_MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.CHECKPOINT.getIndexName())) ) .put( - StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) + StatNames.FORECAST_STATE_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.STATE.getIndexName())) ) + .put(StatNames.FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.HC_FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) .put( - StatNames.MODELS_CHECKPOINT_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) + StatNames.MODEL_INFORMATION.getName(), + new TimeSeriesStat<>(false, new ForecastModelsOnNodeSupplier(forecastCacheProvider, settings, clusterService)) ) .put( - StatNames.ANOMALY_DETECTION_JOB_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + StatNames.CONFIG_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) ) .put( - StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) ) - .put(StatNames.DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.MODEL_COUNT.getName(), new ADStat<>(false, new ModelsOnNodeCountSupplier(modelManager, cacheProvider))) - .put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.MODEL_COUNT.getName(), new TimeSeriesStat<>(false, new ForecastModelsOnNodeCountSupplier(forecastCacheProvider))) .build(); - adStats = new ADStats(stats); + forecastStats = new ForecastStats(forecastStatsMap); - CheckpointReadWorker checkpointReadQueue = new CheckpointReadWorker( + ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + 
TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -684,25 +1181,25 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, TimeSeriesSettings.QUEUE_MAINTENANCE, - modelManager, - checkpoint, - coldstartQueue, - resultWriteQueue, + forecastModelManager, + forecastCheckpoint, + forecastColdstartQueue, stateManager, - anomalyDetectionIndices, - cacheProvider, + forecastIndices, + forecastCacheProvider, TimeSeriesSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - adStats + forecastCheckpointWriteQueue, + forecastStats, + forecastSaveResultStrategy ); - ColdEntityWorker coldEntityQueue = new ColdEntityWorker( + ForecastColdEntityWorker forecastColdEntityQueue = new ForecastColdEntityWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, @@ -710,17 +1207,68 @@ public PooledObject wrap(LinkedBuffer obj) { TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointReadQueue, + forecastCheckpointReadQueue, TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager ); - ADDataMigrator dataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); - HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, dataMigrator, modelManager); + TaskCacheManager forecastTaskCacheManager = new TaskCacheManager(settings, clusterService); + + forecastTaskManager = new ForecastTaskManager( + forecastTaskCacheManager, + client, + xContentRegistry, + forecastIndices, + clusterService, + settings, + threadPool, + stateManager + ); + + ResultBulkIndexingHandler forecastResultHandler = + new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ForecastIndex.RESULT.getIndexName(), + forecastIndices, + this.clientUtil, + this.indexUtils, + clusterService, + ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF + ); + + ForecastSearchHandler forecastSearchHandler = new ForecastSearchHandler(settings, clusterService, client); + + forecastResultResponseRecorder = new ExecuteForecastResultResponseRecorder( + forecastIndices, + forecastResultHandler, + forecastTaskManager, + nodeFilter, + threadPool, + client, + stateManager, + forecastTaskCacheManager, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + + forecastIndexJobActionHandler = new ForecastIndexJobActionHandler( + client, + forecastIndices, + xContentRegistry, + forecastTaskManager, + forecastResultResponseRecorder, + stateManager, + settings + ); - anomalyDetectorRunner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + // ===================== + // common components, need AD/forecasting components to initialize + // ===================== + HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, adDataMigrator, adModelManager); + 
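All of the rate-limiting workers constructed above (checkpoint read, write, and maintain, cold entity, cold start, result write) take the same three sizing inputs: the JVM heap size, a *_MAX_HEAP_PERCENT setting, and a per-request byte estimate such as TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES. The base worker that turns those inputs into a queue bound is not part of this diff, so the following is only a minimal sketch of the presumable arithmetic, under hypothetical names:

public final class QueueCapacitySketch {

    private QueueCapacitySketch() {
    }

    // heapSizeBytes: total JVM heap, the same value handed to each worker above
    // maxHeapPercent: the resolved value of a *_MAX_HEAP_PERCENT setting, e.g. 0.01f
    // bytesPerRequest: estimated size of one queued request, e.g.
    // TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES
    public static int capacity(long heapSizeBytes, float maxHeapPercent, long bytesPerRequest) {
        // convert the heap-percent budget into a maximum number of queued requests
        long budgetBytes = (long) (heapSizeBytes * maxHeapPercent);
        return (int) Math.max(1L, budgetBytes / bytesPerRequest);
    }
}

Keeping that conversion in one shared base class is what lets AD and forecasting reuse a single worker hierarchy while tuning each side's heap percentages independently.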
ADTaskProfileRunner adTaskProfileRunner = new ADTaskProfileRunner(hashRing, client); - ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); adTaskManager = new ADTaskManager( settings, clusterService, @@ -730,24 +1278,18 @@ public PooledObject wrap(LinkedBuffer obj) { nodeFilter, hashRing, adTaskCacheManager, - threadPool - ); - AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler = new AnomalyResultBulkIndexHandler( - client, - settings, threadPool, - this.clientUtil, - this.indexUtils, - clusterService, - anomalyDetectionIndices + stateManager, + adTaskProfileRunner ); + adBatchTaskRunner = new ADBatchTaskRunner( settings, threadPool, clusterService, client, securityClientUtil, - adCircuitBreakerService, + circuitBreakerService, featureManager, adTaskManager, anomalyDetectionIndices, @@ -756,51 +1298,23 @@ public PooledObject wrap(LinkedBuffer obj) { adTaskCacheManager, searchFeatureDao, hashRing, - modelManager - ); - - ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); - - AnomalyIndexHandler anomalyResultHandler = new AnomalyIndexHandler( - client, - settings, - threadPool, - ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService - ); - - adResultResponseRecorder = new ExecuteADResultResponseRecorder( - anomalyDetectionIndices, - anomalyResultHandler, - adTaskManager, - nodeFilter, - threadPool, - client, - stateManager, - adTaskCacheManager, - TimeSeriesSettings.NUM_MIN_SAMPLES + adModelManager ); // return objects used by Guice to inject dependencies for e.g., // transport action handler constructors return ImmutableList .of( - anomalyDetectionIndices, - anomalyDetectorRunner, + // common components searchFeatureDao, imputer, gson, jvmService, hashRing, featureManager, - modelManager, stateManager, - new ADClusterEventListener(clusterService, hashRing), - adCircuitBreakerService, - adStats, + new ClusterEventListener(clusterService, hashRing), + circuitBreakerService, new ClusterManagerEventListener( clusterService, threadPool, @@ -809,23 +1323,51 @@ public PooledObject wrap(LinkedBuffer obj) { clientUtil, nodeFilter, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + ForecastSettings.FORECAST_CHECKPOINT_TTL, settings ), nodeFilter, - multiEntityResultHandler, - checkpoint, - cacheProvider, + // AD components + anomalyDetectionIndices, + anomalyDetectorRunner, + adModelManager, + adStats, + adIndexMemoryPressureAwareResultHandler, + adCheckpoint, + adCacheProvider, adTaskManager, adBatchTaskRunner, adSearchHandler, - coldstartQueue, - resultWriteQueue, - checkpointReadQueue, - checkpointWriteQueue, - coldEntityQueue, - entityColdStarter, + adColdstartQueue, + adResultWriteQueue, + adCheckpointReadQueue, + adCheckpointWriteQueue, + adColdEntityQueue, + adEntityColdStarter, adTaskCacheManager, - adResultResponseRecorder + adResultResponseRecorder, + adIndexJobActionHandler, + adSaveResultStrategy, + new ADTaskProfileRunner(hashRing, client), + // forecast components + forecastIndices, + forecastStats, + forecastModelManager, + forecastIndexMemoryPressureAwareResultHandler, + forecastCheckpoint, + forecastCacheProvider, + forecastColdstartQueue, + forecastResultWriteQueue, + forecastCheckpointReadQueue, + forecastCheckpointWriteQueue, + forecastColdEntityQueue, + forecastColdStarter, + forecastTaskManager, + forecastSearchHandler, + forecastIndexJobActionHandler, + forecastTaskCacheManager, + forecastSaveResultStrategy, + new 
ForecastTaskProfileRunner() ); } @@ -857,14 +1399,29 @@ public List> getExecutorBuilders(Settings settings) { Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 8), TimeValue.timeValueMinutes(10), AD_THREAD_POOL_PREFIX + AD_BATCH_TASK_THREAD_POOL_NAME + ), + new ScalingExecutorBuilder( + FORECAST_THREAD_POOL_NAME, + 1, + // this pool is used by both real time and run once. + // HCAD can be heavy after supporting 1 million entities. + // Limit to use at most 3/4 of the processors. + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) * 3 / 4), + TimeValue.timeValueMinutes(10), + FORECAST_THREAD_POOL_PREFIX + FORECAST_THREAD_POOL_NAME ) ); } @Override public List> getSettings() { - List> enabledSetting = ADEnabledSetting.getInstance().getSettings(); - List> numericSetting = ADNumericSetting.getInstance().getSettings(); + List> adEnabledSetting = ADEnabledSetting.getInstance().getSettings(); + List> adNumericSetting = ADNumericSetting.getInstance().getSettings(); + + List> forecastEnabledSetting = ForecastEnabledSetting.getInstance().getSettings(); + List> forecastNumericSetting = ForecastNumericSetting.getInstance().getSettings(); + + List> timeSeriesEnabledSetting = TimeSeriesEnabledSetting.getInstance().getSettings(); List> systemSetting = ImmutableList .of( @@ -960,6 +1517,15 @@ public List> getSettings() { // ====================================== // Forecast settings // ====================================== + // HC forecasting cache + ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE, + // config parameters + ForecastSettings.FORECAST_INTERVAL, + ForecastSettings.FORECAST_WINDOW_DELAY, + // Fault tolerance + ForecastSettings.FORECAST_BACKOFF_MINUTES, + ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF, // result index rollover ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD, ForecastSettings.FORECAST_RESULT_HISTORY_RETENTION_PERIOD, @@ -972,6 +1538,40 @@ public List> getSettings() { ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT, ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT, ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS, + // restful apis + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + // resource constraint + ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS, + ForecastSettings.MAX_HC_FORECASTERS, + // Security + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + // Historical + ForecastSettings.MAX_OLD_TASK_DOCS_PER_FORECASTER, + // rate limiting + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_COLD_START_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + 
ForecastSettings.FORECAST_CHECKPOINT_TTL, + // query limit + ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL, + ForecastSettings.FORECAST_PAGE_SIZE, + // stats/profile API + ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE, + // clean resource + ForecastSettings.DELETE_FORECAST_RESULT_WHEN_DELETE_FORECASTER, // ====================================== // Common settings // ====================================== @@ -984,7 +1584,14 @@ public List> getSettings() { ); return unmodifiableList( Stream - .of(enabledSetting.stream(), systemSetting.stream(), numericSetting.stream()) + .of( + adEnabledSetting.stream(), + forecastEnabledSetting.stream(), + timeSeriesEnabledSetting.stream(), + systemSetting.stream(), + adNumericSetting.stream(), + forecastNumericSetting.stream() + ) .reduce(Stream::concat) .orElseGet(Stream::empty) .collect(Collectors.toList()) @@ -1010,14 +1617,15 @@ public List getNamedXContent() { public List> getActions() { return Arrays .asList( - new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + // AD + new ActionHandler<>(DeleteADModelAction.INSTANCE, DeleteADModelTransportAction.class), new ActionHandler<>(StopDetectorAction.INSTANCE, StopDetectorTransportAction.class), new ActionHandler<>(RCFResultAction.INSTANCE, RCFResultTransportAction.class), new ActionHandler<>(ThresholdResultAction.INSTANCE, ThresholdResultTransportAction.class), new ActionHandler<>(AnomalyResultAction.INSTANCE, AnomalyResultTransportAction.class), new ActionHandler<>(CronAction.INSTANCE, CronTransportAction.class), new ActionHandler<>(ADStatsNodesAction.INSTANCE, ADStatsNodesTransportAction.class), - new ActionHandler<>(ProfileAction.INSTANCE, ProfileTransportAction.class), + new ActionHandler<>(ADProfileAction.INSTANCE, ADProfileTransportAction.class), new ActionHandler<>(RCFPollingAction.INSTANCE, RCFPollingTransportAction.class), new ActionHandler<>(SearchAnomalyDetectorAction.INSTANCE, SearchAnomalyDetectorTransportAction.class), new ActionHandler<>(SearchAnomalyResultAction.INSTANCE, SearchAnomalyResultTransportAction.class), @@ -1028,8 +1636,8 @@ public List getNamedXContent() { new ActionHandler<>(IndexAnomalyDetectorAction.INSTANCE, IndexAnomalyDetectorTransportAction.class), new ActionHandler<>(AnomalyDetectorJobAction.INSTANCE, AnomalyDetectorJobTransportAction.class), new ActionHandler<>(ADResultBulkAction.INSTANCE, ADResultBulkTransportAction.class), - new ActionHandler<>(EntityResultAction.INSTANCE, EntityResultTransportAction.class), - new ActionHandler<>(EntityProfileAction.INSTANCE, EntityProfileTransportAction.class), + new ActionHandler<>(EntityADResultAction.INSTANCE, EntityADResultTransportAction.class), + new ActionHandler<>(ADEntityProfileAction.INSTANCE, ADEntityProfileTransportAction.class), new ActionHandler<>(SearchAnomalyDetectorInfoAction.INSTANCE, SearchAnomalyDetectorInfoTransportAction.class), new ActionHandler<>(PreviewAnomalyDetectorAction.INSTANCE, PreviewAnomalyDetectorTransportAction.class), new ActionHandler<>(ADBatchAnomalyResultAction.INSTANCE, ADBatchAnomalyResultTransportAction.class), @@ -1039,7 +1647,30 @@ public List getNamedXContent() { new ActionHandler<>(ForwardADTaskAction.INSTANCE, ForwardADTaskTransportAction.class), new ActionHandler<>(DeleteAnomalyResultsAction.INSTANCE, DeleteAnomalyResultsTransportAction.class), new ActionHandler<>(SearchTopAnomalyResultAction.INSTANCE, SearchTopAnomalyResultTransportAction.class), - new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, 
ValidateAnomalyDetectorTransportAction.class) + new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class), + // forecast + new ActionHandler<>(IndexForecasterAction.INSTANCE, IndexForecasterTransportAction.class), + new ActionHandler<>(ForecastResultAction.INSTANCE, ForecastResultTransportAction.class), + new ActionHandler<>(EntityForecastResultAction.INSTANCE, EntityForecastResultTransportAction.class), + new ActionHandler<>(ForecastResultBulkAction.INSTANCE, ForecastResultBulkTransportAction.class), + new ActionHandler<>(ForecastSingleStreamResultAction.INSTANCE, ForecastSingleStreamResultTransportAction.class), + new ActionHandler<>(ForecasterJobAction.INSTANCE, ForecasterJobTransportAction.class), + new ActionHandler<>(StopForecasterAction.INSTANCE, StopForecasterTransportAction.class), + new ActionHandler<>(DeleteForecastModelAction.INSTANCE, DeleteForecastModelTransportAction.class), + new ActionHandler<>(GetForecasterAction.INSTANCE, GetForecasterTransportAction.class), + new ActionHandler<>(DeleteForecasterAction.INSTANCE, DeleteForecasterTransportAction.class), + new ActionHandler<>(SearchForecasterAction.INSTANCE, SearchForecasterTransportAction.class), + new ActionHandler<>(SearchForecasterInfoAction.INSTANCE, SearchForecasterInfoTransportAction.class), + new ActionHandler<>(SearchTopForecastResultAction.INSTANCE, SearchTopForecastResultTransportAction.class), + new ActionHandler<>(ForecastEntityProfileAction.INSTANCE, ForecastEntityProfileTransportAction.class), + new ActionHandler<>(ForecastProfileAction.INSTANCE, ForecastProfileTransportAction.class), + new ActionHandler<>(SearchForecastTasksAction.INSTANCE, SearchForecastTasksTransportAction.class), + new ActionHandler<>(StatsForecasterAction.INSTANCE, StatsForecasterTransportAction.class), + new ActionHandler<>(ForecastStatsNodesAction.INSTANCE, ForecastStatsNodesTransportAction.class), + new ActionHandler<>(ForecastRunOnceAction.INSTANCE, ForecastRunOnceTransportAction.class), + new ActionHandler<>(ForecastRunOnceProfileAction.INSTANCE, ForecastRunOnceProfileTransportAction.class), + new ActionHandler<>(ValidateForecasterAction.INSTANCE, ValidateForecasterTransportAction.class), + new ActionHandler<>(SuggestForecasterParamAction.INSTANCE, SuggestForecasterParamTransportAction.class) ); } @@ -1055,7 +1686,7 @@ public String getJobIndex() { @Override public ScheduledJobRunner getJobRunner() { - return AnomalyDetectorJobRunner.getJobRunnerInstance(); + return JobRunner.getJobRunnerInstance(); } @Override diff --git a/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java index efa48ec7f..dd5ed15c8 100644 --- a/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java +++ b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java @@ -16,8 +16,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.timeseries.settings.TimeSeriesEnabledSetting; /** * Class {@code CircuitBreakerService} provide storing, retrieving circuit breakers functions. 
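The hunk below is the only change this diff makes to isOpen(): the AD-specific breaker toggle gives way to the shared time-series one. The rest of the class is elided here, so this sketch shows only the general shape of such a breaker service (a simplified registry with assumed names, not the actual source):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class BreakerServiceSketch {

    interface Breaker {
        // e.g. true when JVM heap usage crosses a configured threshold
        boolean isOpen();
    }

    private final Map<String, Breaker> breakers = new ConcurrentHashMap<>();

    void registerBreaker(String name, Breaker breaker) {
        breakers.put(name, breaker);
    }

    Boolean isOpen(boolean breakerEnabled) {
        // same guard as the hunk below: a disabled breaker never reports open
        if (!breakerEnabled) {
            return false;
        }
        // otherwise report open as soon as any registered breaker trips
        return breakers.values().stream().anyMatch(Breaker::isOpen);
    }
}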
@@ -76,7 +76,7 @@ public CircuitBreakerService init() { } public Boolean isOpen() { - if (!ADEnabledSetting.isADBreakerEnabled()) { + if (!TimeSeriesEnabledSetting.isBreakerEnabled()) { return false; } diff --git a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java b/src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java similarity index 73% rename from src/main/java/org/opensearch/ad/caching/CacheBuffer.java rename to src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java index fb48fd273..8d5605816 100644 --- a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java +++ b/src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; @@ -25,273 +25,150 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.util.DateUtils; import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.util.DateUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CacheBuffer & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker> + implements + ExpiringState { -/** - * We use a layered cache to manage active entities’ states. We have a two-level - * cache that stores active entity states in each node. Each detector has its - * dedicated cache that stores ten (dynamically adjustable) entities’ states per - * node. A detector’s hottest entities load their states in the dedicated cache. - * If less than 10 entities use the dedicated cache, the secondary cache can use - * the rest of the free memory available to AD. The secondary cache is a shared - * memory among all detectors for the long tail. The shared cache size is 10% - * heap minus all of the dedicated cache consumed by single-entity and multi-entity - * detectors. The shared cache’s size shrinks as the dedicated cache is filled - * up or more detectors are started. - * - * Implementation-wise, both dedicated cache and shared cache are stored in items - * and minimumCapacity controls the boundary. 
If items size is equals to or less - * than minimumCapacity, consider items as dedicated cache; otherwise, consider - * top minimumCapacity active entities (last X entities in priorityList) as in dedicated - * cache and all others in shared cache. - */ -public class CacheBuffer implements ExpiringState { private static final Logger LOG = LogManager.getLogger(CacheBuffer.class); - // max entities to track per detector - private final int MAX_TRACKING_ENTITIES = 1000000; + protected Instant lastUsedTime; + protected final Clock clock; + + protected final MemoryTracker memoryTracker; + protected int checkpointIntervalHrs; + protected final Duration modelTtl; + // max entities to track per detector + protected final int MAX_TRACKING_ENTITIES = 1000000; // the reserved cache size. So no matter how many entities there are, we will // keep the size for minimum capacity entities - private int minimumCapacity; - - // key is model id - private final ConcurrentHashMap> items; + protected int minimumCapacity; // memory consumption per entity - private final long memoryConsumptionPerEntity; - private final MemoryTracker memoryTracker; - private final Duration modelTtl; - private final String detectorId; - private Instant lastUsedTime; - private long reservedBytes; - private final PriorityTracker priorityTracker; - private final Clock clock; - private final CheckpointWriteWorker checkpointWriteQueue; - private final CheckpointMaintainWorker checkpointMaintainQueue; - private int checkpointIntervalHrs; + protected final long memoryConsumptionPerModel; + protected long reservedBytes; + protected final CheckpointWriterType checkpointWriteQueue; + protected final CheckpointMaintainerType checkpointMaintainQueue; + protected final String configId; + protected final Origin origin; + protected final PriorityTracker priorityTracker; + // key is model id + protected final ConcurrentHashMap> items; public CacheBuffer( int minimumCapacity, - long intervalSecs, - long memoryConsumptionPerEntity, - MemoryTracker memoryTracker, Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, Duration modelTtl, - String detectorId, - CheckpointWriteWorker checkpointWriteQueue, - CheckpointMaintainWorker checkpointMaintainQueue, - int checkpointIntervalHrs + long memoryConsumptionPerEntity, + CheckpointWriterType checkpointWriteQueue, + CheckpointMaintainerType checkpointMaintainQueue, + String configId, + long intervalSecs, + Origin origin ) { - this.memoryConsumptionPerEntity = memoryConsumptionPerEntity; - setMinimumCapacity(minimumCapacity); - - this.items = new ConcurrentHashMap<>(); - this.memoryTracker = memoryTracker; - - this.modelTtl = modelTtl; - this.detectorId = detectorId; this.lastUsedTime = clock.instant(); - this.clock = clock; - this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.memoryTracker = memoryTracker; + setCheckpointIntervalHrs(checkpointIntervalHrs); + this.modelTtl = modelTtl; + this.memoryConsumptionPerModel = memoryConsumptionPerEntity; this.checkpointWriteQueue = checkpointWriteQueue; this.checkpointMaintainQueue = checkpointMaintainQueue; - setCheckpointIntervalHrs(checkpointIntervalHrs); - } - - /** - * Update step at period t_k: - * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, - * and n is the period. 
- * @param entityModelId model Id - */ - private void update(String entityModelId) { - priorityTracker.updatePriority(entityModelId); - - Instant now = clock.instant(); - items.get(entityModelId).setLastUsedTime(now); - lastUsedTime = now; - } - - /** - * Insert the model state associated with a model Id to the cache - * @param entityModelId the model Id - * @param value the ModelState - */ - public void put(String entityModelId, ModelState value) { - // race conditions can happen between the put and one of the following operations: - // remove: not a problem as it is unlikely we are removing and putting the same thing - // maintenance: not a problem as we are unlikely to maintain an entry that's not - // already in the cache - // clear: not a problem as we are releasing memory in MemoryTracker. - // The newly added one loses references and soon GC will collect it. - // We have memory tracking correction to fix incorrect memory usage record. - // put from other threads: not a problem as the entry is associated with - // entityModelId and our put is idempotent - put(entityModelId, value, value.getPriority()); + this.configId = configId; + this.origin = origin; + this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.items = new ConcurrentHashMap<>(); + // called after minimumCapacity and memoryConsumptionPerModel are set + setMinimumCapacity(minimumCapacity); } - /** - * Insert the model state associated with a model Id to the cache. Update priority. - * @param entityModelId the model Id - * @param value the ModelState - * @param priority the priority - */ - private void put(String entityModelId, ModelState value, float priority) { - ModelState contentNode = items.get(entityModelId); - if (contentNode == null) { - priorityTracker.addPriority(entityModelId, priority); - items.put(entityModelId, value); - Instant now = clock.instant(); - value.setLastUsedTime(now); - lastUsedTime = now; - // shared cache empty means we are consuming reserved cache. - // Since we have already considered them while allocating CacheBuffer, - // skip bookkeeping. - if (!sharedCacheEmpty()) { - memoryTracker.consumeMemory(memoryConsumptionPerEntity, false, Origin.REAL_TIME_DETECTOR); - } - } else { - update(entityModelId); - items.put(entityModelId, value); + public void setMinimumCapacity(int minimumCapacity) { + if (minimumCapacity < 0) { + throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); } + this.minimumCapacity = minimumCapacity; + this.reservedBytes = memoryConsumptionPerModel * minimumCapacity; } - /** - * Retrieve the ModelState associated with the model Id or null if the CacheBuffer - * contains no mapping for the model Id - * @param key the model Id - * @return the Model state to which the specified model Id is mapped, or null - * if this CacheBuffer contains no mapping for the model Id - */ - public ModelState get(String key) { - // We can get an item that is to be removed soon due to race condition. - // This is acceptable as it won't cause any corruption and exception. - // And this item is used for scoring one last time. - ModelState node = items.get(key); - if (node == null) { - return null; - } - update(key); - return node; + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); } - /** - * Retrieve the ModelState associated with the model Id or null if the CacheBuffer - * contains no mapping for the model Id. 
Compared to get method, the method won't - * increment entity priority. Used in cache buffer maintenance. - * - * @param key the model Id - * @return the Model state to which the specified model Id is mapped, or null - * if this CacheBuffer contains no mapping for the model Id - */ - public ModelState getWithoutUpdatePriority(String key) { - // We can get an item that is to be removed soon due to race condition. - // This is acceptable as it won't cause any corruption and exception. - // And this item is used for scoring one last time. - ModelState node = items.get(key); - if (node == null) { - return null; + public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { + this.checkpointIntervalHrs = checkpointIntervalHrs; + // 0 can cause java.lang.ArithmeticException: / by zero + // negative value is meaningless + if (checkpointIntervalHrs <= 0) { + this.checkpointIntervalHrs = 1; } - return node; } - /** - * - * @return whether there is one item that can be removed from shared cache - */ - public boolean canRemove() { - return !items.isEmpty() && items.size() > minimumCapacity; + public int getCheckpointIntervalHrs() { + return checkpointIntervalHrs; } /** - * remove the smallest priority item. - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove() { - // race conditions can happen between the put and one of the following operations: - // remove from other threads: not a problem. If they remove the same item, - // our method is idempotent. If they remove two different items, - // they don't impact each other. - // maintenance: not a problem as all of the data structures are concurrent. - // Two threads removing the same entry is not a problem. - // clear: not a problem as we are releasing memory in MemoryTracker. - // The removed one loses references and soon GC will collect it. - // We have memory tracking correction to fix incorrect memory usage record. - // put: not a problem as it is unlikely we are removing and putting the same thing - Optional key = priorityTracker.getMinimumPriorityEntityId(); - if (key.isPresent()) { - return remove(key.get()); - } - return null; + * + * @return reserved bytes by the CacheBuffer + */ + public long getReservedBytes() { + return reservedBytes; } /** - * Remove everything associated with the key and make a checkpoint. - * - * @param keyToRemove The key to remove - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove(String keyToRemove) { - return remove(keyToRemove, true); + * + * @return the estimated number of bytes per entity state + */ + public long getMemoryConsumptionPerModel() { + return memoryConsumptionPerModel; } - /** - * Remove everything associated with the key and make a checkpoint if input specified so. 
- * - * @param keyToRemove The key to remove - * @param saveCheckpoint Whether saving checkpoint or not - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove(String keyToRemove, boolean saveCheckpoint) { - priorityTracker.removePriority(keyToRemove); - - // if shared cache is empty, we are using reserved memory - boolean reserved = sharedCacheEmpty(); - - ModelState valueRemoved = items.remove(keyToRemove); + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; - if (valueRemoved != null) { - if (!reserved) { - // release in shared memory - memoryTracker.releaseMemory(memoryConsumptionPerEntity, false, Origin.REAL_TIME_DETECTOR); - } + if (obj instanceof CacheBuffer) { + @SuppressWarnings("unchecked") + CacheBuffer other = + (CacheBuffer) obj; - EntityModel modelRemoved = valueRemoved.getModel(); - if (modelRemoved != null) { - if (saveCheckpoint) { - // null model has only samples. For null model we save a checkpoint - // regardless of last checkpoint time. whether If we don't save, - // we throw the new samples and might never be able to initialize the model - boolean isNullModel = !modelRemoved.getTrcf().isPresent(); - checkpointWriteQueue.write(valueRemoved, isNullModel, RequestPriority.MEDIUM); - } + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(configId, other.configId); - modelRemoved.clear(); - } + return equalsBuilder.isEquals(); } + return false; + } - return valueRemoved; + @Override + public int hashCode() { + return new HashCodeBuilder().append(configId).toHashCode(); } - /** - * @return whether dedicated cache is available or not - */ - public boolean dedicatedCacheAvailable() { - return items.size() < minimumCapacity; + public String getConfigId() { + return configId; } /** @@ -302,56 +179,47 @@ public boolean sharedCacheEmpty() { } /** - * - * @return the estimated number of bytes per entity state - */ - public long getMemoryConsumptionPerEntity() { - return memoryConsumptionPerEntity; - } - - /** - * - * If the cache is not full, check if some other items can replace internal entities - * within the same detector. - * - * @param priority another entity's priority - * @return whether one entity can be replaced by another entity with a certain priority - */ - public boolean canReplaceWithinDetector(float priority) { - if (items.isEmpty()) { - return false; + * + * @return bytes consumed in the shared cache by the CacheBuffer + */ + public long getBytesInSharedCache() { + int sharedCacheEntries = items.size() - minimumCapacity; + if (sharedCacheEntries > 0) { + return memoryConsumptionPerModel * sharedCacheEntries; } - Optional> minPriorityItem = priorityTracker.getMinimumPriority(); - return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); + return 0; } /** - * Replace the smallest priority entity with the input entity - * @param entityModelId the Model Id - * @param value the model State - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key + * Clear associated memory. Used when we are removing a detector.
*/ - public ModelState replace(String entityModelId, ModelState value) { - ModelState replaced = remove(); - put(entityModelId, value); - return replaced; + public void clear() { + // race conditions can happen between the put and remove/maintenance/put: + // not a problem as we are releasing memory in MemoryTracker. + // The newly added one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + memoryTracker.releaseMemory(getReservedBytes(), true, origin); + if (!sharedCacheEmpty()) { + memoryTracker.releaseMemory(getBytesInSharedCache(), false, origin); + } + items.clear(); + priorityTracker.clearPriority(); } /** * Remove expired state and save checkpoints of existing states * @return removed states */ - public List> maintenance() { + public List> maintenance() { List modelsToSave = new ArrayList<>(); - List> removedStates = new ArrayList<>(); + List> removedStates = new ArrayList<>(); Instant now = clock.instant(); int currentHour = DateUtils.getUTCHourOfDay(now); int currentSlot = currentHour % checkpointIntervalHrs; items.entrySet().stream().forEach(entry -> { String entityModelId = entry.getKey(); try { - ModelState modelState = entry.getValue(); + ModelState modelState = entry.getValue(); if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) { // race conditions can happen between the put and one of the following operations: @@ -397,7 +265,7 @@ public List> maintenance() { new CheckpointMaintainRequest( // the request expires when the next maintainance starts System.currentTimeMillis() + modelTtl.toMillis(), - detectorId, + configId, RequestPriority.LOW, entityModelId ) @@ -414,9 +282,97 @@ } /** + * Remove everything associated with the key and make a checkpoint if input specified so. + * + * @param keyToRemove The key to remove + * @param saveCheckpoint Whether saving checkpoint or not + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove(String keyToRemove, boolean saveCheckpoint) { + priorityTracker.removePriority(keyToRemove); + + // if shared cache is empty, we are using reserved memory + boolean reserved = sharedCacheEmpty(); + + ModelState valueRemoved = items.remove(keyToRemove); + + if (valueRemoved != null) { + if (!reserved) { + // release in shared memory + memoryTracker.releaseMemory(memoryConsumptionPerModel, false, origin); + } + + if (saveCheckpoint) { + // a null model has only samples. For a null model we save a checkpoint + // regardless of the last checkpoint time. If we don't save, + // we throw away the new samples and might never be able to initialize the model + checkpointWriteQueue.write(valueRemoved, valueRemoved.getModel().isEmpty(), RequestPriority.MEDIUM); + } + + valueRemoved.clear(); + } + + return valueRemoved; + } + + /** + * Remove everything associated with the key and make a checkpoint. * - * @return the number of active entities + * @param keyToRemove The key to remove + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key */ + public ModelState remove(String keyToRemove) { + return remove(keyToRemove, true); + } + + public PriorityTracker getPriorityTracker() { + return priorityTracker; + } + + /** + * remove the smallest priority item.
+ * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove() { + // race conditions can happen between the put and one of the following operations: + // remove from other threads: not a problem. If they remove the same item, + // our method is idempotent. If they remove two different items, + // they don't impact each other. + // maintenance: not a problem as all of the data structures are concurrent. + // Two threads removing the same entry is not a problem. + // clear: not a problem as we are releasing memory in MemoryTracker. + // The removed one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + // put: not a problem as it is unlikely we are removing and putting the same thing + Optional key = priorityTracker.getMinimumPriorityEntityId(); + if (key.isPresent()) { + return remove(key.get()); + } + return null; + } + + /** + * + * @return whether there is one item that can be removed from shared cache + */ + public boolean canRemove() { + return !items.isEmpty() && items.size() > minimumCapacity; + } + + /** + * @return whether dedicated cache is available or not + */ + public boolean dedicatedCacheAvailable() { + return items.size() < minimumCapacity; + } + + /** + * + * @return the number of active entities + */ public int getActiveEntities() { return items.size(); } @@ -436,7 +392,7 @@ public boolean isActive(String entityModelId) { * @return Last used time of the model */ public long getLastUsedTime(String entityModelId) { - ModelState state = items.get(entityModelId); + ModelState state = items.get(entityModelId); if (state != null) { return state.getLastUsedTime().toEpochMilli(); } @@ -448,105 +404,139 @@ * @param entityModelId entity Id * @return Get the model of an entity */ - public Optional getModel(String entityModelId) { - return Optional.of(items).map(map -> map.get(entityModelId)).map(state -> state.getModel()); + public ModelState getModelState(String entityModelId) { + // a direct map lookup: this returns the ModelState (or null) as is, so the + // Optional wrapping that the old getModel performed is no longer needed. + return items.get(entityModelId); } /** - * Clear associated memory. Used when we are removing an detector. + * Update step at period t_k: + * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, + * and n is the period. + * @param entityModelId model Id */ - public void clear() { - // race conditions can happen between the put and remove/maintenance/put: - // not a problem as we are releasing memory in MemoryTracker. + private void update(String entityModelId) { + priorityTracker.updatePriority(entityModelId); + + Instant now = clock.instant(); + items.get(entityModelId).setLastUsedTime(now); + lastUsedTime = now; + } + + /** + * Insert the model state associated with a model Id to the cache + * @param entityModelId the model Id + * @param value the ModelState + */ + public void put(String entityModelId, ModelState value) { + // race conditions can happen between the put and one of the following operations: + // remove: not a problem as it is unlikely we are removing and putting the same thing + // maintenance: not a problem as we are unlikely to maintain an entry that's not + // already in the cache + // clear: not a problem as we are releasing memory in MemoryTracker.
// The newly added one loses references and soon GC will collect it. // We have memory tracking correction to fix incorrect memory usage record. - memoryTracker.releaseMemory(getReservedBytes(), true, Origin.REAL_TIME_DETECTOR); - if (!sharedCacheEmpty()) { - memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.REAL_TIME_DETECTOR); + // put from other threads: not a problem as the entry is associated with + // entityModelId and our put is idempotent + put(entityModelId, value, value.getPriority()); + } + + /** + * Insert the model state associated with a model Id to the cache. Update priority. + * @param entityModelId the model Id + * @param value the ModelState + * @param priority the priority + */ + private void put(String entityModelId, ModelState value, float priority) { + ModelState contentNode = items.get(entityModelId); + if (contentNode == null) { + priorityTracker.addPriority(entityModelId, priority); + items.put(entityModelId, value); + Instant now = clock.instant(); + value.setLastUsedTime(now); + lastUsedTime = now; + // shared cache empty means we are consuming reserved cache. + // Since we have already considered them while allocating CacheBuffer, + // skip bookkeeping. + if (!sharedCacheEmpty()) { + memoryTracker.consumeMemory(memoryConsumptionPerModel, false, origin); + } + } else { + update(entityModelId); + items.put(entityModelId, value); } - items.clear(); - priorityTracker.clearPriority(); } /** - * - * @return reserved bytes by the CacheBuffer + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id */ - public long getReservedBytes() { - return reservedBytes; + public ModelState get(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; + } + update(key); + return node; } /** + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id. Compared to get method, the method won't + * increment entity priority. Used in cache buffer maintenance. * - * @return bytes consumed in the shared cache by the CacheBuffer + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id */ - public long getBytesInSharedCache() { - int sharedCacheEntries = items.size() - minimumCapacity; - if (sharedCacheEntries > 0) { - return memoryConsumptionPerEntity * sharedCacheEntries; + public ModelState getWithoutUpdatePriority(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; } - return 0; + return node; } - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) + /** + * + * If the cache is not full, check if some other items can replace internal entities + * within the same config. 
+ * + * @param priority another entity's priority + * @return whether one entity can be replaced by another entity with a certain priority + */ + public boolean canReplaceWithinConfig(float priority) { + if (items.isEmpty()) { return false; - if (obj instanceof InitProgressProfile) { - CacheBuffer other = (CacheBuffer) obj; - - EqualsBuilder equalsBuilder = new EqualsBuilder(); - equalsBuilder.append(detectorId, other.detectorId); - - return equalsBuilder.isEquals(); } - return false; - } - - @Override - public int hashCode() { - return new HashCodeBuilder().append(detectorId).toHashCode(); - } - - @Override - public boolean expired(Duration stateTtl) { - return expired(lastUsedTime, stateTtl, clock.instant()); + Optional> minPriorityItem = priorityTracker.getMinimumPriority(); + return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); } - public String getId() { - return detectorId; + /** + * Replace the smallest priority entity with the input entity + * @param entityModelId the Model Id + * @param value the model State + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState replace(String entityModelId, ModelState value) { + ModelState replaced = remove(); + put(entityModelId, value); + return replaced; } - public List> getAllModels() { + public List> getAllModelStates() { return items.values().stream().collect(Collectors.toList()); } - - public PriorityTracker getPriorityTracker() { - return priorityTracker; - } - - public void setMinimumCapacity(int minimumCapacity) { - if (minimumCapacity < 0) { - throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); - } - this.minimumCapacity = minimumCapacity; - this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity; - } - - public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { - this.checkpointIntervalHrs = checkpointIntervalHrs; - // 0 can cause java.lang.ArithmeticException: / by zero - // negative value is meaningless - if (checkpointIntervalHrs <= 0) { - this.checkpointIntervalHrs = 1; - } - } - - public int getCheckpointIntervalHrs() { - return checkpointIntervalHrs; - } } diff --git a/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java b/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java new file mode 100644 index 000000000..9b4a53705 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.caching; + +import org.opensearch.common.inject.Provider; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A wrapper to call concrete implementation of caching. Used in transport + * action. Don't use interface because transport action handler constructor + * requires a concrete class as input. 
+ * + */ +public class CacheProvider> + implements + Provider { + private CacheType cache; + + public CacheProvider() { + + } + + @Override + public CacheType get() { + return cache; + } + + public void set(CacheType cache) { + this.cache = cache; + } +} diff --git a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java b/src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java similarity index 60% rename from src/main/java/org/opensearch/ad/caching/DoorKeeper.java rename to src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java index 5bb5e3cd5..488b5b7b4 100644 --- a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java +++ b/src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java @@ -9,57 +9,51 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.MaintenanceState; -import com.google.common.base.Charsets; -import com.google.common.hash.BloomFilter; -import com.google.common.hash.Funnels; - /** - * A bloom filter with regular reset. + * A hashmap that tracks the exact frequency of each element and resets regularly. * - * Reference: https://arxiv.org/abs/1512.00727 + * The name door keeper derives from https://arxiv.org/abs/1512.00727 * */ public class DoorKeeper implements MaintenanceState, ExpiringState { private final Logger LOG = LogManager.getLogger(DoorKeeper.class); // stores entity's model id - private BloomFilter bloomFilter; - // the number of expected insertions to the constructed BloomFilter; must be positive private final long expectedInsertions; - // the desired false positive probability (must be positive and less than 1.0) - private final double fpp; + private Map frequencyMap; private Instant lastMaintenanceTime; private final Duration resetInterval; private final Clock clock; private Instant lastAccessTime; + private final int countThreshold; - public DoorKeeper(long expectedInsertions, double fpp, Duration resetInterval, Clock clock) { + public DoorKeeper(long expectedInsertions, Duration resetInterval, Clock clock, int countThreshold) { this.expectedInsertions = expectedInsertions; - this.fpp = fpp; this.resetInterval = resetInterval; this.clock = clock; + this.countThreshold = countThreshold; this.lastAccessTime = clock.instant(); maintenance(); } - public boolean mightContain(String modelId) { + public void put(String modelId) { this.lastAccessTime = clock.instant(); - return bloomFilter.mightContain(modelId); - } - - public boolean put(String modelId) { - this.lastAccessTime = clock.instant(); - return bloomFilter.put(modelId); + this.frequencyMap.put(modelId, this.frequencyMap.getOrDefault(modelId, 0) + 1); + if (frequencyMap.size() > expectedInsertions) { + reset(); + } } /** @@ -67,13 +61,22 @@ public boolean put(String modelId) { */ @Override public void maintenance() { - if (bloomFilter == null || lastMaintenanceTime.plus(resetInterval).isBefore(clock.instant())) { + if (frequencyMap == null || lastMaintenanceTime.plus(resetInterval).isBefore(clock.instant())) { LOG.debug("maintaining for doorkeeper"); - bloomFilter = BloomFilter.create(Funnels.stringFunnel(Charsets.US_ASCII), expectedInsertions, fpp); - lastMaintenanceTime = clock.instant(); + reset(); } } + private void reset() { + 
frequencyMap = new HashMap<>(); + lastMaintenanceTime = clock.instant(); + } + + public boolean appearsMoreThanThreshold(String item) { + this.lastAccessTime = clock.instant(); + return this.frequencyMap.getOrDefault(item, 0) > countThreshold; + } + @Override public boolean expired(Duration stateTtl) { // ignore stateTtl since we have customized resetInterval diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java similarity index 64% rename from src/main/java/org/opensearch/ad/caching/PriorityCache.java rename to src/main/java/org/opensearch/timeseries/caching/PriorityCache.java index 40e28975d..300509f54 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java @@ -9,16 +9,14 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -32,22 +30,13 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Triple; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.ad.util.DateUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -57,47 +46,59 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.MemoryTracker.Origin; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DateUtils; +import 
com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -public class PriorityCache implements EntityCache { - private final Logger LOG = LogManager.getLogger(PriorityCache.class); +public abstract class PriorityCache & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker, CacheBufferType extends CacheBuffer> + implements + TimeSeriesCache { + + private static final Logger LOG = LogManager.getLogger(PriorityCache.class); // detector id -> CacheBuffer, weight based - private final Map activeEnities; - private final CheckpointDao checkpointDao; - private volatile int dedicatedCacheSize; + private final Map activeEnities; + private final CheckpointDaoType checkpointDao; + protected volatile int hcDedicatedCacheSize; // LRU Cache, key is model id - private Cache> inActiveEntities; - private final MemoryTracker memoryTracker; + private Cache> inActiveEntities; + protected final MemoryTracker memoryTracker; private final ReentrantLock maintenanceLock; private final int numberOfTrees; - private final Clock clock; - private final Duration modelTtl; + protected final Clock clock; + protected final Duration modelTtl; // A door keeper placed in front of the inactive entity cache to // filter out unpopular items that are not likely to appear more // than once. Key is config id private Map doorKeepers; private ThreadPool threadPool; + private String threadPoolName; private Random random; - private CheckpointWriteWorker checkpointWriteQueue; // iterating through all of inactive entities is heavy. We don't want to do // it again and again for no obvious benefits.
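
The DoorKeeper rewrite above trades the Guava BloomFilter's probabilistic membership test for exact counting: a model id is admitted only once it has been seen more than countThreshold times, and the whole map is dropped when it outgrows expectedInsertions or the reset interval elapses. A minimal standalone sketch of that admission pattern (the class name and layout are illustrative, not part of the patch):

    import java.util.HashMap;
    import java.util.Map;

    // Exact-count door keeper: admit an item only after repeated appearances.
    class CountingDoorKeeper {
        private final long expectedInsertions; // reset once the map outgrows this
        private final int countThreshold;      // appearances required before admission
        private Map<String, Integer> frequencyMap = new HashMap<>();

        CountingDoorKeeper(long expectedInsertions, int countThreshold) {
            this.expectedInsertions = expectedInsertions;
            this.countThreshold = countThreshold;
        }

        void put(String modelId) {
            frequencyMap.merge(modelId, 1, Integer::sum);
            if (frequencyMap.size() > expectedInsertions) {
                frequencyMap = new HashMap<>(); // bounds memory, like DoorKeeper.reset()
            }
        }

        boolean appearsMoreThanThreshold(String modelId) {
            return frequencyMap.getOrDefault(modelId, 0) > countThreshold;
        }
    }

Compared to the bloom filter this yields no false positives, at the cost of one Integer per tracked id; the size-triggered reset keeps that bounded.
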
private Instant lastInActiveEntityMaintenance; protected int maintenanceFreqConstant; - private CheckpointMaintainWorker checkpointMaintainQueue; - private int checkpointIntervalHrs; + protected int checkpointIntervalHrs; + private Origin origin; public PriorityCache( - CheckpointDao checkpointDao, - int dedicatedCacheSize, + CheckpointDaoType checkpointDao, + int hcDedicatedCacheSize, Setting checkpointTtl, int maxInactiveStates, MemoryTracker memoryTracker, @@ -106,22 +107,24 @@ public PriorityCache( ClusterService clusterService, Duration modelTtl, ThreadPool threadPool, - CheckpointWriteWorker checkpointWriteQueue, + String threadPoolName, int maintenanceFreqConstant, - CheckpointMaintainWorker checkpointMaintainQueue, Settings settings, - Setting checkpointSavingFreq + Setting checkpointSavingFreq, + Origin origin, + Setting dedicatedCacheSizeSetting, + Setting modelMaxSizePercent ) { this.checkpointDao = checkpointDao; this.activeEnities = new ConcurrentHashMap<>(); - this.dedicatedCacheSize = dedicatedCacheSize; - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_DEDICATED_CACHE_SIZE, (it) -> { - this.dedicatedCacheSize = it; - this.setDedicatedCacheSizeListener(); + this.hcDedicatedCacheSize = hcDedicatedCacheSize; + clusterService.getClusterSettings().addSettingsUpdateConsumer(dedicatedCacheSizeSetting, (it) -> { + this.hcDedicatedCacheSize = it; + this.setHCDedicatedCacheSizeListener(); this.tryClearUpMemory(); }, this::validateDedicatedCacheSize); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MODEL_MAX_SIZE_PERCENTAGE, it -> this.tryClearUpMemory()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(modelMaxSizePercent, it -> this.tryClearUpMemory()); this.memoryTracker = memoryTracker; this.maintenanceLock = new ReentrantLock(); @@ -138,56 +141,58 @@ public PriorityCache( }); this.threadPool = threadPool; + this.threadPoolName = threadPoolName; this.random = new Random(42); - this.checkpointWriteQueue = checkpointWriteQueue; this.lastInActiveEntityMaintenance = Instant.MIN; this.maintenanceFreqConstant = maintenanceFreqConstant; - this.checkpointMaintainQueue = checkpointMaintainQueue; this.checkpointIntervalHrs = DateUtils.toDuration(checkpointSavingFreq.get(settings)).toHoursPart(); clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointSavingFreq, it -> { this.checkpointIntervalHrs = DateUtils.toDuration(it).toHoursPart(); this.setCheckpointFreqListener(); }); + this.origin = origin; } @Override - public ModelState get(String modelId, AnomalyDetector detector) { - String detectorId = detector.getId(); - CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); - ModelState modelState = buffer.get(modelId); + public ModelState get(String modelId, Config config) { + String configId = config.getId(); + CacheBufferType buffer = activeEnities.get(configId); + ModelState modelState = null; + if (buffer != null) { + modelState = buffer.get(modelId); + } // during maintenance period, stop putting new entries if (!maintenanceLock.isLocked() && modelState == null) { - if (ADEnabledSetting.isDoorKeeperInCacheEnabled()) { - DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(detectorId, id -> { - // reset every 60 intervals - return new DoorKeeper( - TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, - TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, - detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), - clock - ); - }); + DoorKeeper doorKeeper = 
doorKeepers.computeIfAbsent(configId, id -> { + // reset every 60 intervals + return new DoorKeeper( + TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock, + TimeSeriesSettings.DOOR_KEEPER_COUNT_THRESHOLD + ); + }); - // first hit, ignore - // since door keeper may get reset during maintenance, it is possible - // the entity is still active even though door keeper has no record of - // this model Id. We have to call isActive method to make sure. Otherwise, - // the entity might miss an anomaly result every 60 intervals due to door keeper - // reset. - if (!doorKeeper.mightContain(modelId) && !isActive(detectorId, modelId)) { - doorKeeper.put(modelId); - return null; - } + // first few hits, ignore + // since door keeper may get reset during maintenance, it is possible + // the entity is still active even though door keeper has no record of + // this model Id. We have to call isActive method to make sure. Otherwise, + // the entity might miss a result every 60 intervals due to door keeper + // reset. + if (!doorKeeper.appearsMoreThanThreshold(modelId) && !isActive(configId, modelId)) { + doorKeeper.put(modelId); + return null; } try { - ModelState state = inActiveEntities.get(modelId, new Callable>() { + ModelState state = inActiveEntities.get(modelId, new Callable>() { @Override - public ModelState call() { - return new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + public ModelState call() { + return createEmptyModelState(modelId, configId); } + }); // make sure no model has been stored due to previous race conditions @@ -214,7 +219,9 @@ public ModelState call() { // intervals. // update state using new priority or create a new one - state.setPriority(buffer.getPriorityTracker().getUpdatedPriority(state.getPriority())); + if (buffer != null) { + state.setPriority(buffer.getPriorityTracker().getUpdatedPriority(state.getPriority())); + } // adjust shared memory in case we have used dedicated cache memory for other detectors if (random.nextInt(maintenanceFreqConstant) == 1) { @@ -229,57 +236,58 @@ public ModelState call() { return modelState; } - private Optional> getStateFromInactiveEntiiyCache(String modelId) { + private Optional> getStateFromInactiveEntiiyCache(String modelId) { if (modelId == null) { return Optional.empty(); } - // null if not even recorded in inActiveEntities yet because of doorKeeper + // null if not even recorded in inActiveEntities yet because of doorKeeper or first time start config return Optional.ofNullable(inActiveEntities.getIfPresent(modelId)); } @Override - public boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate) { - if (toUpdate == null) { + public boolean hostIfPossible(Config config, ModelState toUpdate) { + if (toUpdate == null || toUpdate.getModel() == null) { return false; } String modelId = toUpdate.getModelId(); - String detectorId = toUpdate.getId(); + String configId = toUpdate.getConfigId(); - if (Strings.isEmpty(modelId) || Strings.isEmpty(detectorId)) { + if (Strings.isEmpty(modelId) || Strings.isEmpty(configId)) { return false; } - CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); + CacheBufferType buffer = computeBufferIfAbsent(config, configId); - Optional> state = getStateFromInactiveEntiiyCache(modelId); - if (false == state.isPresent()) { - return false; + Optional> state = getStateFromInactiveEntiiyCache(modelId); + ModelState modelState = null; + if 
(state.isPresent()) { + modelState = state.get(); + } else { + modelState = createEmptyModelState(modelId, configId); } - ModelState modelState = state.get(); - float priority = modelState.getPriority(); toUpdate.setLastUsedTime(clock.instant()); toUpdate.setPriority(priority); // current buffer's dedicated cache has free slots or can allocate in shared cache - if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // buffer.put will call MemoryTracker.consumeMemory buffer.put(modelId, toUpdate); return true; } - if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // buffer.put will call MemoryTracker.consumeMemory buffer.put(modelId, toUpdate); return true; } // can replace an entity in the same CacheBuffer living in reserved or shared cache - if (buffer.canReplaceWithinDetector(priority)) { - ModelState removed = buffer.replace(modelId, toUpdate); + if (buffer.canReplaceWithinConfig(priority)) { + ModelState removed = buffer.replace(modelId, toUpdate); // null in the case of some other threads have emptied the queue at // the same time so there is nothing to replace if (removed != null) { @@ -291,10 +299,10 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState // If two threads try to remove the same entity and add their own state, the 2nd remove // returns null and only the first one succeeds. float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); - Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); - CacheBuffer bufferToRemove = bufferToRemoveEntity.getLeft(); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + CacheBufferType bufferToRemove = bufferToRemoveEntity.getLeft(); String entityModelId = bufferToRemoveEntity.getMiddle(); - ModelState removed = null; + ModelState removed = null; if (bufferToRemove != null && ((removed = bufferToRemove.remove(entityModelId)) != null)) { buffer.put(modelId, toUpdate); addIntoInactiveCache(removed); @@ -304,7 +312,7 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState return false; } - private void addIntoInactiveCache(ModelState removed) { + private void addIntoInactiveCache(ModelState removed) { if (removed == null) { return; } @@ -314,10 +322,10 @@ private void addIntoInactiveCache(ModelState removed) { inActiveEntities.put(removed.getModelId(), removed); } - private void addEntity(List destination, Entity entity, String detectorId) { + private void addEntity(List destination, Entity entity, String configId) { // It's possible our doorkeeper prevented the entity from entering inactive entities cache if (entity != null) { - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (modelId.isPresent() && inActiveEntities.getIfPresent(modelId.get()) != null) { destination.add(entity); } @@ -325,38 +333,31 @@ private void addEntity(List destination, Entity entity, String detectorI } @Override - public Pair, List> selectUpdateCandidate( - Collection cacheMissEntities, - String detectorId, - AnomalyDetector detector - ) { + public Pair, List> selectUpdateCandidate(Collection cacheMissEntities, String configId, Config config) { List hotEntities = new ArrayList<>(); List coldEntities = new ArrayList<>(); - CacheBuffer buffer =
activeEnities.get(detectorId); + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { - // don't want to create side-effects by creating a CacheBuffer - // In current implementation, this branch is impossible as we call - // PriorityCache.get method before invoking this method. The - // PriorityCache.get method creates a CacheBuffer if not present. - // Since this method is public, need to deal with this case in case of misuse. - return Pair.of(hotEntities, coldEntities); + // when a config is just started or during run once, there is + // no cache buffer yet. Make every cache-miss entity hot + return Pair.of(new ArrayList<>(cacheMissEntities), coldEntities); } Iterator cacheMissEntitiesIter = cacheMissEntities.iterator(); // current buffer's dedicated cache has free slots while (cacheMissEntitiesIter.hasNext() && buffer.dedicatedCacheAvailable()) { - addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + addEntity(hotEntities, cacheMissEntitiesIter.next(), configId); } - while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // can allocate in shared cache // race conditions can happen when multiple threads evaluating this condition. // This is a problem as our AD memory usage is close to full and we put // more things than we planned. One model in HCAD is small, // it is fine we exceed a little. We have regular maintenance to remove // extra memory usage. - addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + addEntity(hotEntities, cacheMissEntitiesIter.next(), configId); } // check if we can replace anything in dedicated or shared cache @@ -370,23 +371,23 @@ public Pair, List> selectUpdateCandidate( // thread safe as each detector has one thread at one time and only the // thread can access its buffer. Entity entity = cacheMissEntitiesIter.next(); - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (false == modelId.isPresent()) { continue; } - Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); + Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); if (false == state.isPresent()) { // not even recorded in inActiveEntities yet because of doorKeeper continue; } - ModelState modelState = state.get(); + ModelState modelState = state.get(); float priority = modelState.getPriority(); - if (buffer.canReplaceWithinDetector(priority)) { - addEntity(hotEntities, entity, detectorId); + if (buffer.canReplaceWithinConfig(priority)) { + addEntity(hotEntities, entity, configId); } else { // re-evaluate replacement condition in other buffers otherBufferReplaceCandidates.add(entity); @@ -395,7 +396,7 @@ public Pair, List> selectUpdateCandidate( // record current minimum priority among all detectors to save redundant // scanning of all CacheBuffers - CacheBuffer bufferToRemove = null; + CacheBufferType bufferToRemove = null; float minPriority = Float.MIN_VALUE; // check if we can replace in other CacheBuffer @@ -405,77 +406,64 @@ public Pair, List> selectUpdateCandidate( // If two threads try to remove the same entity and add their own state, the 2nd remove // returns null and only the first one succeeds.
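
The selection logic above amounts to a tiered funnel; a condensed sketch of the same flow, where canReplaceSomewhere is a hypothetical stand-in for the within-config check already made and the cross-buffer scan that continues below:

    for (Entity entity : cacheMissEntities) {
        if (buffer.dedicatedCacheAvailable()                                     // tier 1: free dedicated slot
            || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())  // tier 2: shared memory
            || canReplaceSomewhere(entity)) {                                    // tier 3: evict a lower-priority model
            hotEntities.add(entity);   // model gets loaded this interval
        } else {
            coldEntities.add(entity);  // deferred to the cold-start path
        }
    }
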
Entity entity = cacheMissEntitiesIter.next(); - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (false == modelId.isPresent()) { continue; } - Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); + Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); if (false == inactiveState.isPresent()) { // empty state should not stand a chance to replace others continue; } - ModelState state = inactiveState.get(); + ModelState state = inactiveState.get(); float priority = state.getPriority(); float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); if (scaledPriority <= minPriority) { // not even larger than the minPriority, we can put this to coldEntities - addEntity(coldEntities, entity, detectorId); + addEntity(coldEntities, entity, configId); continue; } // Float.MIN_VALUE means we need to re-iterate through all CacheBuffers if (minPriority == Float.MIN_VALUE) { - Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); bufferToRemove = bufferToRemoveEntity.getLeft(); minPriority = bufferToRemoveEntity.getRight(); } if (bufferToRemove != null) { - addEntity(hotEntities, entity, detectorId); + addEntity(hotEntities, entity, configId); // reset minPriority after the replacement so that we need to iterate all CacheBuffer // again minPriority = Float.MIN_VALUE; } else { // after trying everything, we can now safely put this to cold entities list - addEntity(coldEntities, entity, detectorId); + addEntity(coldEntities, entity, configId); } } return Pair.of(hotEntities, coldEntities); } - private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detectorId) { - CacheBuffer buffer = activeEnities.get(detectorId); + private CacheBufferType computeBufferIfAbsent(Config config, String configId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { - long requiredBytes = getRequiredMemory(detector, dedicatedCacheSize); + long requiredBytes = getRequiredMemory(config, config.isHighCardinality() ? hcDedicatedCacheSize : 1); if (memoryTracker.canAllocateReserved(requiredBytes)) { - memoryTracker.consumeMemory(requiredBytes, true, Origin.REAL_TIME_DETECTOR); - long intervalSecs = detector.getIntervalInSeconds(); - - buffer = new CacheBuffer( - dedicatedCacheSize, - intervalSecs, - getRequiredMemory(detector, 1), - memoryTracker, - clock, - modelTtl, - detectorId, - checkpointWriteQueue, - checkpointMaintainQueue, - checkpointIntervalHrs - ); - activeEnities.put(detectorId, buffer); + memoryTracker.consumeMemory(requiredBytes, true, origin); + buffer = createEmptyCacheBuffer(config, requiredBytes); + activeEnities.put(configId, buffer); // There can be race conditions between tryClearUpMemory and // activeEntities.put above as tryClearUpMemory accesses activeEnities too. // Put tryClearUpMemory after consumeMemory to prevent that. 
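
The ordering called out in this comment is the crux: clearMemory() walks activeEnities on another thread and syncs the MemoryTracker from what it finds, so if the new buffer became visible before its bytes were recorded, a concurrent scan could sync totals that the later consumeMemory call then double-counts. Roughly, under that hypothetical interleaving:

    // T1: computeBufferIfAbsent           T2: maintenance -> clearMemory
    // activeEnities.put(configId, buffer)
    //                                      recalculateUsedMemory();        // already sees buffer
    //                                      memoryTracker.syncMemoryState() // totals include it
    // memoryTracker.consumeMemory(...)     // charged again on top
    //
    // Reserving first and cleaning up last, as the patch does, avoids this:
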
tryClearUpMemory(); } else { - throw new LimitExceededException(detectorId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); + throw new LimitExceededException(configId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); } } @@ -484,20 +472,12 @@ private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detec /** * - * @param detector Detector config accessor + * @param config Config accessor * @param numberOfEntity number of entities * @return Memory in bytes required for hosting numberOfEntity entities */ - private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { - int dimension = detector.getEnabledFeatureIds().size() * detector.getShingleSize(); - return numberOfEntity * memoryTracker - .estimateTRCFModelSize( - dimension, - numberOfTrees, - TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, - detector.getShingleSize().intValue(), - true - ); + private long getRequiredMemory(Config config, int numberOfEntity) { + return numberOfEntity * getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees); } /** @@ -511,12 +491,12 @@ private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { * @param candidatePriority the candidate entity's priority * @return the CacheBuffer if we can find a CacheBuffer to make room for the candidate entity */ - private Triple canReplaceInSharedCache(CacheBuffer originBuffer, float candidatePriority) { - CacheBuffer minPriorityBuffer = null; + private Triple canReplaceInSharedCache(CacheBufferType originBuffer, float candidatePriority) { + CacheBufferType minPriorityBuffer = null; float minPriority = candidatePriority; String minPriorityEntityModelId = null; - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); if (buffer != originBuffer && buffer.canRemove()) { Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); if (!priorityEntry.isPresent()) { @@ -536,12 +516,12 @@ private Triple canReplaceInSharedCache(CacheBuffer o /** * Clear up overused memory. Can happen due to race condition or other detectors * consumes resources from shared memory. - * tryClearUpMemory is ran using AD threadpool because the function is expensive. + * tryClearUpMemory is run using the analysis-specific threadpool because the function is expensive. */ private void tryClearUpMemory() { try { if (maintenanceLock.tryLock()) { - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); + threadPool.executor(threadPoolName).execute(() -> clearMemory()); } else { threadPool.schedule(() -> { try { @@ -549,7 +529,7 @@ private void tryClearUpMemory() { } catch (Exception e) { LOG.error("Failed to clear up memory taken by CacheBuffer.
Will retry during maintenance."); } - }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), threadPoolName); } } finally { if (maintenanceLock.isHeldByCurrentThread()) { @@ -561,12 +541,12 @@ private void tryClearUpMemory() { private void clearMemory() { recalculateUsedMemory(); long memoryToShed = memoryTracker.memoryToShed(); - PriorityQueue> removalCandiates = null; + PriorityQueue> removalCandiates = null; if (memoryToShed > 0) { // sort the triple in an ascending order of priority removalCandiates = new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft())); - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); if (!priorityEntry.isPresent()) { continue; @@ -579,12 +559,12 @@ private void clearMemory() { } while (memoryToShed > 0) { if (false == removalCandiates.isEmpty()) { - Triple toRemove = removalCandiates.poll(); - CacheBuffer minPriorityBuffer = toRemove.getMiddle(); + Triple toRemove = removalCandiates.poll(); + CacheBufferType minPriorityBuffer = toRemove.getMiddle(); String minPriorityEntityModelId = toRemove.getRight(); - ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); - memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerEntity(); + ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); + memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerModel(); addIntoInactiveCache(removed); if (minPriorityBuffer.canRemove()) { @@ -609,12 +589,12 @@ private void clearMemory() { private void recalculateUsedMemory() { long reserved = 0; long shared = 0; - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); reserved += buffer.getReservedBytes(); shared += buffer.getBytesInSharedCache(); } - memoryTracker.syncMemoryState(Origin.REAL_TIME_DETECTOR, reserved + shared, reserved); + memoryTracker.syncMemoryState(origin, reserved + shared, reserved); } /** @@ -630,15 +610,15 @@ public void maintenance() { // clean up memory if we allocate more memory than we should tryClearUpMemory(); activeEnities.entrySet().stream().forEach(cacheBufferEntry -> { - String detectorId = cacheBufferEntry.getKey(); - CacheBuffer cacheBuffer = cacheBufferEntry.getValue(); + String configId = cacheBufferEntry.getKey(); + CacheBufferType cacheBuffer = cacheBufferEntry.getValue(); // remove expired cache buffer if (cacheBuffer.expired(modelTtl)) { - activeEnities.remove(detectorId); + activeEnities.remove(configId); cacheBuffer.clear(); } else { - List> removedStates = cacheBuffer.maintenance(); - for (ModelState state : removedStates) { + List> removedStates = cacheBuffer.maintenance(); + for (ModelState state : removedStates) { addIntoInactiveCache(state); } } @@ -647,11 +627,11 @@ public void maintenance() { maintainInactiveCache(); doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { - String detectorId = doorKeeperEntry.getKey(); + String configId = doorKeeperEntry.getKey(); DoorKeeper doorKeeper = doorKeeperEntry.getValue(); // doorKeeper has its own state ttl if (doorKeeper.expired(null)) { - doorKeepers.remove(detectorId); + doorKeepers.remove(configId); } else { 
doorKeeper.maintenance(); } @@ -666,19 +646,19 @@ public void maintenance() { /** * Permanently deletes models hosted in memory and persisted in index. * - * @param detectorId id the of the detector for which models are to be permanently deleted + * @param configId id of the config for which models are to be permanently deleted */ @Override - public void clear(String detectorId) { - if (Strings.isEmpty(detectorId)) { + public void clear(String configId) { + if (Strings.isEmpty(configId)) { return; } - CacheBuffer buffer = activeEnities.remove(detectorId); + CacheBufferType buffer = activeEnities.remove(configId); if (buffer != null) { buffer.clear(); } - checkpointDao.deleteModelCheckpointByDetectorId(detectorId); - doorKeepers.remove(detectorId); + checkpointDao.deleteModelCheckpointByConfigId(configId); + doorKeepers.remove(configId); } /** @@ -688,7 +668,7 @@ public void clear(String detectorId) { */ @Override public int getActiveEntities(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); if (cacheBuffer != null) { return cacheBuffer.getActiveEntities(); } @@ -697,13 +677,13 @@ public int getActiveEntities(String detectorId) { /** * Whether an entity is active or not - * @param detectorId The Id of the detector that an entity belongs to + * @param configId The Id of the config that an entity belongs to * @param entityModelId Entity's Model Id * @return Whether an entity is active or not */ @Override - public boolean isActive(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public boolean isActive(String configId, String entityModelId) { + CacheBufferType cacheBuffer = activeEnities.get(configId); if (cacheBuffer != null) { return cacheBuffer.isActive(entityModelId); } @@ -711,32 +691,22 @@ public boolean isActive(String detectorId, String entityModelId) { } @Override - public long getTotalUpdates(String detectorId) { + public long getTotalUpdates(String configId) { return Optional .of(activeEnities) - .map(entities -> entities.get(detectorId)) + .map(entities -> entities.get(configId)) .map(buffer -> buffer.getPriorityTracker().getHighestPriorityEntityId()) .map(entityModelIdOptional -> entityModelIdOptional.get()) - .map(entityModelId -> getTotalUpdates(detectorId, entityModelId)) + .map(entityModelId -> getTotalUpdates(configId, entityModelId)) .orElse(0L); } @Override - public long getTotalUpdates(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null) { - Optional modelOptional = cacheBuffer.getModel(entityModelId); - // TODO: make it work for shingles.
samples.size() is not the real shingle - long accumulatedShingles = modelOptional - .flatMap(model -> model.getTrcf()) - .map(trcf -> trcf.getForest()) - .map(rcf -> rcf.getTotalUpdates()) - .orElseGet( - () -> modelOptional.map(model -> model.getSamples()).map(samples -> samples.size()).map(Long::valueOf).orElse(0L) - ); - return accumulatedShingles; - } - return 0L; + public long getTotalUpdates(String configId, String entityModelId) { + return Optional + .ofNullable(activeEnities.get(configId)) + .map(cacheBuffer -> getTotalUpdates(cacheBuffer.getModelState(entityModelId))) + .orElse(0L); } /** @@ -756,24 +726,25 @@ public int getTotalActiveEntities() { * @return list of modelStates */ @Override - public List> getAllModels() { - List> states = new ArrayList<>(); - activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModels())); + public List> getAllModels() { + List> states = new ArrayList<>(); + activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModelStates())); return states; } /** - * Gets all of a detector's model sizes hosted on a node + * Gets all of a config's model sizes hosted on a node * + * @param configId config Id * @return a map of model id to its memory size */ @Override - public Map getModelSize(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public Map getModelSize(String configId) { + CacheBufferType cacheBuffer = activeEnities.get(configId); Map res = new HashMap<>(); if (cacheBuffer != null) { - long size = cacheBuffer.getMemoryConsumptionPerEntity(); - cacheBuffer.getAllModels().forEach(entry -> res.put(entry.getModelId(), size)); + long size = cacheBuffer.getMemoryConsumptionPerModel(); + cacheBuffer.getAllModelStates().forEach(entry -> res.put(entry.getModelId(), size)); } return res; } @@ -792,8 +763,8 @@ public Map getModelSize(String detectorId) { * milliseconds when the entity's state is lastly used. Otherwise, return -1. */ @Override - public long getLastActiveMs(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public long getLastActiveTime(String detectorId, String entityModelId) { + CacheBufferType cacheBuffer = activeEnities.get(detectorId); long lastUsedMs = -1; if (cacheBuffer != null) { lastUsedMs = cacheBuffer.getLastUsedTime(entityModelId); @@ -801,7 +772,7 @@ public long getLastActiveMs(String detectorId, String entityModelId) { return lastUsedMs; } } - ModelState stateInActive = inActiveEntities.getIfPresent(entityModelId); + ModelState stateInActive = inActiveEntities.getIfPresent(entityModelId); if (stateInActive != null) { lastUsedMs = stateInActive.getLastUsedTime().toEpochMilli(); } @@ -815,7 +786,7 @@ public void releaseMemoryForOpenCircuitBreaker() { tryClearUpMemory(); activeEnities.values().stream().forEach(cacheBuffer -> { if (cacheBuffer.canRemove()) { - ModelState removed = cacheBuffer.remove(); + ModelState removed = cacheBuffer.remove(); addIntoInactiveCache(removed); } }); @@ -831,9 +802,9 @@ private void maintainInactiveCache() { inActiveEntities.cleanUp(); // make sure no model has been stored due to bugs - for (ModelState state : inActiveEntities.asMap().values()) { - EntityModel model = state.getModel(); - if (model != null && model.getTrcf().isPresent()) { + for (ModelState state : inActiveEntities.asMap().values()) { + Optional modelOptional = state.getModel(); + if (modelOptional.isPresent()) { LOG.warn(new ParameterizedMessage("Inactive entity's model is not null: [{}].
Maybe there are bugs.", state.getModelId())); state.setModel(null); } @@ -846,8 +817,8 @@ private void maintainInactiveCache() { * Called when dedicated cache size changes. Will adjust existing cache buffer's * cache size */ - private void setDedicatedCacheSizeListener() { - activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(dedicatedCacheSize)); + private void setHCDedicatedCacheSizeListener() { + activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(hcDedicatedCacheSize)); } private void setCheckpointFreqListener() { @@ -856,20 +827,16 @@ private void setCheckpointFreqListener() { @Override public List getAllModelProfile(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - List res = new ArrayList<>(); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); if (cacheBuffer != null) { - long size = cacheBuffer.getMemoryConsumptionPerEntity(); - cacheBuffer.getAllModels().forEach(entry -> { - EntityModel model = entry.getModel(); - Entity entity = null; - if (model != null && model.getEntity().isPresent()) { - entity = model.getEntity().get(); - } - res.add(new ModelProfile(entry.getModelId(), entity, size)); - }); + long size = cacheBuffer.getMemoryConsumptionPerModel(); + return cacheBuffer + .getAllModelStates() + .stream() + .map(entry -> new ModelProfile(entry.getModelId(), entry.getEntity().orElse(null), size)) + .collect(Collectors.toList()); } - return res; + return Collections.emptyList(); } /** @@ -881,14 +848,14 @@ public List getAllModelProfile(String detectorId) { */ @Override public Optional getModelProfile(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null && cacheBuffer.getModel(entityModelId).isPresent()) { - EntityModel model = cacheBuffer.getModel(entityModelId).get(); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null && cacheBuffer.getModelState(entityModelId) != null) { + ModelState modelState = cacheBuffer.getModelState(entityModelId); Entity entity = null; - if (model != null && model.getEntity().isPresent()) { - entity = model.getEntity().get(); + if (modelState != null && modelState.getEntity().isPresent()) { + entity = modelState.getEntity().get(); } - return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerEntity())); + return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerModel())); } return Optional.empty(); } @@ -900,11 +867,11 @@ public Optional getModelProfile(String detectorId, String entityMo * @param newDedicatedCacheSize the new dedicated cache size to validate */ private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { - if (this.dedicatedCacheSize < newDedicatedCacheSize) { - int delta = newDedicatedCacheSize - this.dedicatedCacheSize; + if (this.hcDedicatedCacheSize < newDedicatedCacheSize) { + int delta = newDedicatedCacheSize - this.hcDedicatedCacheSize; long totalIncreasedBytes = 0; - for (CacheBuffer cacheBuffer : activeEnities.values()) { - totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerEntity() * delta; + for (CacheBufferType cacheBuffer : activeEnities.values()) { + totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerModel() * delta; } if (false == memoryTracker.canAllocateReserved(totalIncreasedBytes)) { @@ -915,13 +882,13 @@ private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { /** * Get a model 
state without incurring priority update. Used in maintenance. - * @param detectorId Detector Id + * @param configId Config Id * @param modelId Model Id * @return Model state */ @Override - public Optional> getForMaintainance(String detectorId, String modelId) { - CacheBuffer buffer = activeEnities.get(detectorId); + public Optional> getForMaintainance(String configId, String modelId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { return Optional.empty(); } @@ -929,31 +896,31 @@ public Optional> getForMaintainance(String detectorId, S } /** - * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. - * @param detectorId Detector Id - * @param entityModelId Model Id + * Remove model from active entity buffer and delete checkpoint. Used to clean corrupted model. + * @param configId config Id + * @param modelId Model Id */ @Override - public void removeEntityModel(String detectorId, String entityModelId) { - CacheBuffer buffer = activeEnities.get(detectorId); + public void removeModel(String configId, String modelId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer != null) { - ModelState removed = null; - if ((removed = buffer.remove(entityModelId, false)) != null) { + ModelState removed = buffer.remove(modelId, false); + if (removed != null) { addIntoInactiveCache(removed); } } checkpointDao .deleteModelCheckpoint( - entityModelId, + modelId, ActionListener .wrap( - r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", entityModelId)), - e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", entityModelId), e) + r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)), + e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), e) ) ); } - private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { + private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { return CacheBuilder .newBuilder() .expireAfterAccess(inactiveEntityTtl.toHours(), TimeUnit.HOURS) @@ -961,4 +928,8 @@ private Cache> createInactiveCache(Duration inac .concurrencyLevel(1) .build(); } + + protected abstract ModelState createEmptyModelState(String modelId, String configId); + + protected abstract CacheBufferType createEmptyCacheBuffer(Config config, long memoryConsumptionPerEntity); } diff --git a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java b/src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java similarity index 97% rename from src/main/java/org/opensearch/ad/caching/PriorityTracker.java rename to src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java index 439d67679..07f2087ec 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.util.AbstractMap.SimpleImmutableEntry; @@ -236,7 +236,7 @@ public void updatePriority(String entityId) { * @param entityId Entity Id * @param priority priority */ - protected void addPriority(String entityId, float priority) { + public void addPriority(String entityId, float priority) { PriorityNode node = new PriorityNode(entityId, priority); key2Priority.put(entityId, node); priorityList.add(node); @@ -260,7 +260,7 @@ private void adjustSizeIfRequired() { * Remove an entity in the tracker * @param entityId Entity Id */ - protected void removePriority(String entityId) { + public void removePriority(String entityId) { // remove if the key matches; priority does not matter priorityList.remove(new PriorityNode(entityId, 0)); key2Priority.remove(entityId); @@ -269,7 +269,7 @@ protected void removePriority(String entityId) { /** * Remove all entities */ - protected void clearPriority() { + public void clearPriority() { key2Priority.clear(); priorityList.clear(); } @@ -292,7 +292,7 @@ protected void clearPriority() { * * @return new priority */ - float getUpdatedPriority(float oldPriority) { + public float getUpdatedPriority(float oldPriority) { long increment = computeWeightedPriorityIncrement(); oldPriority += Math.log(1 + Math.exp(increment - oldPriority)); // if overflow happens, using the most recent decayed count instead. @@ -319,7 +319,7 @@ float getUpdatedPriority(float oldPriority) { * @param currentPriority Current priority * @return the scaled priority */ - float getScaledPriority(float currentPriority) { + public float getScaledPriority(float currentPriority) { return currentPriority - computeWeightedPriorityIncrement(); } diff --git a/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java new file mode 100644 index 000000000..fa5b0c1eb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.caching; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.timeseries.AnalysisModelSize; +import org.opensearch.timeseries.CleanState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public interface TimeSeriesCache extends MaintenanceState, CleanState, AnalysisModelSize { + /** + * + * @param config Analysis config + * @param toUpdate Model state candidate + * @return if we can host the given model state + */ + boolean hostIfPossible(Config config, ModelState toUpdate); + + /** + * Get a model state without incurring a priority update or loading state from disk. Used in maintenance. + * @param configId Config Id + * @param modelId Model Id + * @return Model state + */ + Optional> getForMaintainance(String configId, String modelId); + + /** + * Get the ModelState associated with the modelId.
May or may not load the + * ModelState depending on the underlying cache's memory consumption. + * + * @param modelId Model Id + * @param config config accessor + * @return the ModelState associated with the model id, or null if there is no cached item + * for it + */ + ModelState get(String modelId, Config config); + + /** + * Whether an entity is active or not + * @param configId The Id of the config that an entity belongs to + * @param entityModelId Entity model Id + * @return Whether an entity is active or not + */ + boolean isActive(String configId, String entityModelId); + + /** + * Get total updates of the config's most active entity's RCF model. + * + * @param configId config id + * @return RCF model total updates of most active entity. + */ + long getTotalUpdates(String configId); + + /** + * Get RCF model total updates of specific entity + * + * @param configId config id + * @param entityModelId entity model id + * @return RCF model total updates of specific entity. + */ + long getTotalUpdates(String configId, String entityModelId); + + /** + * Gets modelStates of all models hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(); + + /** + * Get the number of active entities of a config + * @param configId Config Id + * @return The number of active entities + */ + int getActiveEntities(String configId); + + /** + * + * @return total active entities in the cache + */ + int getTotalActiveEntities(); + + /** + * Return the last active time of an entity's state. + * + * If the entity's state is active in the cache, the value indicates when the cache + * was last accessed (get/put). If the entity's state is inactive in the cache, + * the value indicates when the cache state is created or when the entity is evicted + * from active entity cache. + * + * @param configId The Id of the config that an entity belongs to + * @param entityModelId Entity's Model Id + * @return if the entity is in the cache, return the timestamp in epoch + * milliseconds when the entity's state was last used. Otherwise, return -1. + */ + long getLastActiveTime(String configId, String entityModelId); + + /** + * Release memory when memory circuit breaker is open + */ + void releaseMemoryForOpenCircuitBreaker(); + + /** + * Select candidate entities for which we can load models + * @param cacheMissEntities Cache miss entities + * @param configId Config Id + * @param config Config object + * @return A list of entities that are admitted into the cache as a result of the + * update and the left-over entities + */ + Pair, List> selectUpdateCandidate(Collection cacheMissEntities, String configId, Config config); + + /** + * + * @param configId Config Id + * @return a config's model information + */ + List getAllModelProfile(String configId); + + /** + * Gets an entity's model sizes + * + * @param configId Config Id + * @param entityModelId Entity's model Id + * @return the entity's memory size + */ + Optional getModelProfile(String configId, String entityModelId); + + /** + * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model.
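+ * <p>A typical cleanup call, with illustrative identifiers (not taken from the patch):
+ * <pre>{@code
+ * // evict the corrupted model from the hot buffer and delete its checkpoint doc
+ * cache.removeModel(configId, entityModelId);
+ * }</pre>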
+ * @param configId config Id + * @param entityModelId Model Id + */ + void removeModel(String configId, String entityModelId); + + /** + * + * @param config Config accessor + * @param memoryTracker memory tracker + * @param numberOfTrees number of trees + * @return Memory in bytes required for hosting one entity model + */ + default long getRequiredMemoryPerEntity(Config config, MemoryTracker memoryTracker, int numberOfTrees) { + int dimension = config.getEnabledFeatureIds().size() * config.getShingleSize(); + return memoryTracker + .estimateTRCFModelSize( + dimension, + numberOfTrees, + TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, + config.getShingleSize().intValue(), + true + ); + } + + default long getTotalUpdates(ModelState modelState) { + // TODO: make it work for shingles. samples.size() is not the real shingle + long accumulatedShingles = Optional + .ofNullable(modelState) + .flatMap(model -> model.getModel()) + .map(trcf -> trcf.getForest()) + .map(rcf -> rcf.getTotalUpdates()) + .orElseGet( + () -> Optional + .ofNullable(modelState) + .map(model -> model.getSamples()) + .map(samples -> samples.size()) + .map(Long::valueOf) + .orElse(0L) + ); + return accumulatedShingles; + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java b/src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java similarity index 97% rename from src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java rename to src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java index 4050c22f5..7dfcf37fb 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java @@ -9,12 +9,10 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.model.TaskType.taskTypeToString; @@ -59,6 +57,7 @@ import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; import org.opensearch.timeseries.util.ExceptionUtil; /** @@ -212,10 +211,10 @@ private void checkIfRealtimeTaskExistsAndBackfill( BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, jobId)); if (job.isEnabled()) { - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); } - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(1); SearchRequest searchRequest = new SearchRequest(DETECTION_STATE_INDEX).source(searchSourceBuilder); client.search(searchRequest, ActionListener.wrap(r -> { diff --git a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java
b/src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java similarity index 72% rename from src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java rename to src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java index 4f629c7bb..3712bfb73 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.util.concurrent.Semaphore; @@ -23,18 +23,18 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.gateway.GatewayService; -public class ADClusterEventListener implements ClusterStateListener { - private static final Logger LOG = LogManager.getLogger(ADClusterEventListener.class); - static final String NOT_RECOVERED_MSG = "Cluster is not recovered yet."; - static final String IN_PROGRESS_MSG = "Cluster state change in progress, return."; - static final String NODE_CHANGED_MSG = "Cluster node changed"; +public class ClusterEventListener implements ClusterStateListener { + private static final Logger LOG = LogManager.getLogger(ClusterEventListener.class); + public static final String NOT_RECOVERED_MSG = "Cluster is not recovered yet."; + public static final String IN_PROGRESS_MSG = "Cluster state change in progress, return."; + public static final String NODE_CHANGED_MSG = "Cluster node changed"; private final Semaphore inProgress; private HashRing hashRing; private final ClusterService clusterService; @Inject - public ADClusterEventListener(ClusterService clusterService, HashRing hashRing) { + public ClusterEventListener(ClusterService clusterService, HashRing hashRing) { this.clusterService = clusterService; this.clusterService.addListener(this); this.hashRing = hashRing; @@ -55,16 +55,13 @@ public void clusterChanged(ClusterChangedEvent event) { } try { - // Init AD version hash ring as early as possible. Some test case may fail as AD + // Init version hash ring as early as possible. Some test cases may fail if the // version hash ring is not initialized when tests run.
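
The listener keeps the concurrency contract of the AD-specific version it replaces: at most one hash-ring build runs at a time behind a non-blocking semaphore, and events that arrive mid-build are dropped rather than queued. In outline (condensed from buildCircles further down, not a verbatim quote):

    public void buildCircles(DiscoveryNodes.Delta delta, ActionListener<Boolean> listener) {
        if (!buildHashRingSemaphore.tryAcquire()) { // a build is already in flight
            listener.onResponse(false);             // drop this event instead of blocking
            return;
        }
        // ... rebuild the version-keyed circles, releasing the permit when done ...
    }
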
if (!hashRing.isHashRingInited()) { hashRing .buildCircles( ActionListener - .wrap( - r -> LOG.info("Init AD version hash ring successfully"), - e -> LOG.error("Failed to init AD version hash ring") - ) + .wrap(r -> LOG.info("Init version hash ring successfully"), e -> LOG.error("Failed to init version hash ring")) ); } Delta delta = event.nodesDelta(); @@ -74,7 +71,7 @@ public void clusterChanged(ClusterChangedEvent event) { hashRing.addNodeChangeEvent(); hashRing.buildCircles(delta, ActionListener.runAfter(ActionListener.wrap(hasRingBuildDone -> { LOG.info("Hash ring build result: {}", hasRingBuildDone); - }, e -> { LOG.error("Failed updating AD version hash ring", e); }), () -> inProgress.release())); + }, e -> { LOG.error("Failed updating version hash ring", e); }), () -> inProgress.release())); } else { inProgress.release(); } diff --git a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java b/src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java similarity index 56% rename from src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java rename to src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java index 8b8a40405..e4bd8ae1d 100644 --- a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java @@ -9,14 +9,14 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.time.Clock; import java.time.Duration; +import java.util.Arrays; +import java.util.List; -import org.opensearch.ad.cluster.diskcleanup.IndexCleanup; -import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; -import org.opensearch.ad.util.DateUtils; +import org.opensearch.ad.cluster.diskcleanup.ADCheckpointIndexRetention; import org.opensearch.client.Client; import org.opensearch.cluster.LocalNodeClusterManagerListener; import org.opensearch.cluster.service.ClusterService; @@ -24,16 +24,20 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.cluster.diskcleanup.ForecastCheckpointIndexRetention; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.DateUtils; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.annotations.VisibleForTesting; public class ClusterManagerEventListener implements LocalNodeClusterManagerListener { - private Cancellable checkpointIndexRetentionCron; + private Cancellable adCheckpointIndexRetentionCron; + private Cancellable forecastCheckpointIndexRetentionCron; private Cancellable hourlyCron; private ClusterService clusterService; private ThreadPool threadPool; @@ -41,7 +45,8 @@ public class ClusterManagerEventListener implements LocalNodeClusterManagerListe private Clock clock; private ClientUtil clientUtil; private DiscoveryNodeFilterer nodeFilter; - private Duration checkpointTtlDuration; + private Duration adCheckpointTtlDuration; + private Duration forecastCheckpointTtlDuration; public ClusterManagerEventListener( ClusterService clusterService, @@ -50,7 +55,8 @@ public ClusterManagerEventListener( Clock clock, ClientUtil clientUtil, DiscoveryNodeFilterer nodeFilter, 
- Setting checkpointTtl, + Setting adCheckpointTtl, + Setting forecastCheckpointTtl, Settings settings ) { this.clusterService = clusterService; @@ -61,15 +67,22 @@ public ClusterManagerEventListener( this.clientUtil = clientUtil; this.nodeFilter = nodeFilter; - this.checkpointTtlDuration = DateUtils.toDuration(checkpointTtl.get(settings)); + this.adCheckpointTtlDuration = DateUtils.toDuration(adCheckpointTtl.get(settings)); + this.forecastCheckpointTtlDuration = DateUtils.toDuration(forecastCheckpointTtl.get(settings)); - clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointTtl, it -> { - this.checkpointTtlDuration = DateUtils.toDuration(it); - cancel(checkpointIndexRetentionCron); + clusterService.getClusterSettings().addSettingsUpdateConsumer(adCheckpointTtl, it -> { + this.adCheckpointTtlDuration = DateUtils.toDuration(it); + cancel(adCheckpointIndexRetentionCron); IndexCleanup indexCleanup = new IndexCleanup(client, clientUtil, clusterService); - checkpointIndexRetentionCron = threadPool + adCheckpointIndexRetentionCron = threadPool .scheduleWithFixedDelay( - new ModelCheckpointIndexRetention(checkpointTtlDuration, clock, indexCleanup), + new ADCheckpointIndexRetention(adCheckpointTtlDuration, clock, indexCleanup), + TimeValue.timeValueHours(24), + executorName() + ); + forecastCheckpointIndexRetentionCron = threadPool + .scheduleWithFixedDelay( + new ForecastCheckpointIndexRetention(forecastCheckpointTtlDuration, clock, indexCleanup), TimeValue.timeValueHours(24), executorName() ); @@ -89,19 +102,27 @@ public void beforeStop() { }); } - if (checkpointIndexRetentionCron == null) { + if (adCheckpointIndexRetentionCron == null) { IndexCleanup indexCleanup = new IndexCleanup(client, clientUtil, clusterService); - checkpointIndexRetentionCron = threadPool + adCheckpointIndexRetentionCron = threadPool + .scheduleWithFixedDelay( + new ADCheckpointIndexRetention(adCheckpointTtlDuration, clock, indexCleanup), + TimeValue.timeValueHours(24), + executorName() + ); + forecastCheckpointIndexRetentionCron = threadPool .scheduleWithFixedDelay( - new ModelCheckpointIndexRetention(checkpointTtlDuration, clock, indexCleanup), + new ForecastCheckpointIndexRetention(forecastCheckpointTtlDuration, clock, indexCleanup), TimeValue.timeValueHours(24), executorName() ); clusterService.addLifecycleListener(new LifecycleListener() { @Override public void beforeStop() { - cancel(checkpointIndexRetentionCron); - checkpointIndexRetentionCron = null; + cancel(adCheckpointIndexRetentionCron); + adCheckpointIndexRetentionCron = null; + cancel(forecastCheckpointIndexRetentionCron); + forecastCheckpointIndexRetentionCron = null; } }); } @@ -110,9 +131,11 @@ public void beforeStop() { @Override public void offClusterManager() { cancel(hourlyCron); - cancel(checkpointIndexRetentionCron); hourlyCron = null; - checkpointIndexRetentionCron = null; + cancel(adCheckpointIndexRetentionCron); + adCheckpointIndexRetentionCron = null; + cancel(forecastCheckpointIndexRetentionCron); + forecastCheckpointIndexRetentionCron = null; } private void cancel(Cancellable cron) { @@ -122,11 +145,11 @@ private void cancel(Cancellable cron) { } @VisibleForTesting - protected Cancellable getCheckpointIndexRetentionCron() { - return checkpointIndexRetentionCron; + public List getCheckpointIndexRetentionCron() { + return Arrays.asList(adCheckpointIndexRetentionCron, forecastCheckpointIndexRetentionCron); } - protected Cancellable getHourlyCron() { + public Cancellable getHourlyCron() { return hourlyCron; } diff --git 
a/src/main/java/org/opensearch/ad/cluster/DailyCron.java b/src/main/java/org/opensearch/timeseries/cluster/DailyCron.java similarity index 86% rename from src/main/java/org/opensearch/ad/cluster/DailyCron.java rename to src/main/java/org/opensearch/timeseries/cluster/DailyCron.java index 2692608d2..4a7c9a6d5 100644 --- a/src/main/java/org/opensearch/ad/cluster/DailyCron.java +++ b/src/main/java/org/opensearch/timeseries/cluster/DailyCron.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.time.Clock; import java.time.Duration; @@ -30,9 +30,9 @@ public class DailyCron implements Runnable { private static final Logger LOG = LogManager.getLogger(DailyCron.class); protected static final String FIELD_MODEL = "queue"; - static final String CANNOT_DELETE_OLD_CHECKPOINT_MSG = "Cannot delete old checkpoint."; - static final String CHECKPOINT_NOT_EXIST_MSG = "Checkpoint index does not exist."; - static final String CHECKPOINT_DELETED_MSG = "checkpoint docs get deleted"; + public static final String CANNOT_DELETE_OLD_CHECKPOINT_MSG = "Cannot delete old checkpoint."; + public static final String CHECKPOINT_NOT_EXIST_MSG = "Checkpoint index does not exist."; + public static final String CHECKPOINT_DELETED_MSG = "checkpoint docs get deleted"; private final Clock clock; private final Duration checkpointTtl; @@ -54,7 +54,7 @@ public void run() { QueryBuilders .rangeQuery(CommonName.TIMESTAMP) .lte(clock.millis() - checkpointTtl.toMillis()) - .format(ADCommonName.EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) ) ) .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); diff --git a/src/main/java/org/opensearch/ad/cluster/HashRing.java b/src/main/java/org/opensearch/timeseries/cluster/HashRing.java similarity index 76% rename from src/main/java/org/opensearch/ad/cluster/HashRing.java rename to src/main/java/org/opensearch/timeseries/cluster/HashRing.java index 30ea1724f..759d70113 100644 --- a/src/main/java/org/opensearch/ad/cluster/HashRing.java +++ b/src/main/java/org/opensearch/timeseries/cluster/HashRing.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; @@ -35,7 +35,7 @@ import org.opensearch.action.admin.cluster.node.info.NodeInfo; import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -69,11 +69,11 @@ public class HashRing { // Semaphore to control only 1 thread can build AD hash ring. private Semaphore buildHashRingSemaphore; - // This field is to track AD version of all nodes. - // Key: node id; Value: AD node info - private Map nodeAdVersions; - // This field records AD version hash ring in realtime way. Historical detection will use this hash ring. - // Key: AD version; Value: hash ring which only contains eligible data nodes + // This field is to track time series plugin version of all nodes. + // Key: node id; Value: node info + private Map nodeVersions; + // This field records time series version hash ring in realtime way. Historical detection will use this hash ring. 
+ // Key: time series version; Value: hash ring which only contains eligible data nodes private TreeMap> circles; // Track if hash ring inited or not. If not inited, the first clusterManager event will try to init it. private AtomicBoolean hashRingInited; @@ -82,8 +82,8 @@ public class HashRing { private long lastUpdateForRealtimeAD; // Cool down period before next hash ring rebuild. We need this as realtime AD needs stable hash ring. private volatile TimeValue coolDownPeriodForRealtimeAD; - // This field records AD version hash ring with cooldown period. Realtime job will use this hash ring. - // Key: AD version; Value: hash ring which only contains eligible data nodes + // This field records time series version hash ring with cooldown period. Realtime job will use this hash ring. + // Key: time series version; Value: hash ring which only contains eligible data nodes private TreeMap> circlesForRealtimeAD; // Record node change event. Will check if there is node change event when rebuild AD hash ring with @@ -95,7 +95,7 @@ public class HashRing { private final ADDataMigrator dataMigrator; private final Clock clock; private final Client client; - private final ModelManager modelManager; + private final ADModelManager modelManager; public HashRing( DiscoveryNodeFilterer nodeFilter, @@ -104,7 +104,7 @@ public HashRing( Client client, ClusterService clusterService, ADDataMigrator dataMigrator, - ModelManager modelManager + ADModelManager modelManager ) { this.nodeFilter = nodeFilter; this.buildHashRingSemaphore = new Semaphore(1); @@ -116,7 +116,7 @@ public HashRing( this.client = client; this.clusterService = clusterService; this.dataMigrator = dataMigrator; - this.nodeAdVersions = new ConcurrentHashMap<>(); + this.nodeVersions = new ConcurrentHashMap<>(); this.circles = new TreeMap<>(); this.circlesForRealtimeAD = new TreeMap<>(); this.hashRingInited = new AtomicBoolean(false); @@ -129,17 +129,17 @@ public boolean isHashRingInited() { } /** - * Build AD version based circles with discovery node delta change. Listen to clusterManager event in - * {@link ADClusterEventListener#clusterChanged(ClusterChangedEvent)}. + * Build version based circles with discovery node delta change. Listen to clusterManager event in + * {@link ClusterEventListener#clusterChanged(ClusterChangedEvent)}. * Will remove the removed nodes from cache and send request to newly added nodes to get their - * plugin information; then add new nodes to AD version hash ring. + * plugin information; then add new nodes to version hash ring. * * @param delta discovery node delta change * @param listener action listener */ public void buildCircles(DiscoveryNodes.Delta delta, ActionListener listener) { if (!buildHashRingSemaphore.tryAcquire()) { - LOG.info("AD version hash ring change is in progress. Can't build hash ring for node delta event."); + LOG.info("hash ring change is in progress. Can't build hash ring for node delta event."); listener.onResponse(false); return; } @@ -151,14 +151,14 @@ public void buildCircles(DiscoveryNodes.Delta delta, ActionListener lis } /** - * Build AD version based circles by comparing with all eligible data nodes. + * Build version based circles by comparing with all eligible data nodes. * 1. Remove nodes which are not eligible now; - * 2. Add nodes which are not in AD version circles. + * 2. Add nodes which are not in version circles. 
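 * <p>For illustration, a sketch of that delta with generics spelled out (Sets is
 * Guava's com.google.common.collect.Sets, which the method below already uses):
 * <pre>
 * Set<String> nodeIds = new HashSet<>();
 * for (DiscoveryNode node : allNodes) {      // nodes visible to the node filter
 *     nodeIds.add(node.getId());
 * }
 * Set<String> removed = Sets.difference(nodeVersions.keySet(), nodeIds);
 * Set<String> added = Sets.difference(nodeIds, nodeVersions.keySet());
 * </pre>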
* @param actionListener action listener */ public void buildCircles(ActionListener actionListener) { if (!buildHashRingSemaphore.tryAcquire()) { - LOG.info("AD version hash ring change is in progress. Can't rebuild hash ring."); + LOG.info("hash ring change is in progress. Can't rebuild hash ring."); actionListener.onResponse(false); return; } @@ -167,35 +167,35 @@ public void buildCircles(ActionListener actionListener) { for (DiscoveryNode node : allNodes) { nodeIds.add(node.getId()); } - Set currentNodeIds = nodeAdVersions.keySet(); + Set currentNodeIds = nodeVersions.keySet(); Set removedNodeIds = Sets.difference(currentNodeIds, nodeIds); Set addedNodeIds = Sets.difference(nodeIds, currentNodeIds); buildCircles(removedNodeIds, addedNodeIds, actionListener); } - public void buildCirclesForRealtimeAD() { + public void buildCirclesForRealtime() { if (nodeChangeEvents.isEmpty()) { return; } - buildCircles(ActionListener.wrap(r -> { LOG.debug("build circles on AD versions successfully"); }, e -> { - LOG.error("Failed to build circles on AD versions", e); - })); + buildCircles( + ActionListener.wrap(r -> { LOG.debug("build circles successfully"); }, e -> { LOG.error("Failed to build circles", e); }) + ); } /** - * Build AD version hash ring. - * 1. Delete removed nodes from AD version hash ring. - * 2. Add new nodes to AD version hash ring + * Build version hash ring. + * 1. Delete removed nodes from version hash ring. + * 2. Add new nodes to version hash ring * - * If fail to acquire semaphore to update AD version hash ring, will return false to + * If we fail to acquire the semaphore to update the version hash ring, we will return false to the * action listener; otherwise we will return true. The "true" response just means we got the * semaphore and finished rebuilding the hash ring, but the hash ring may stay the same. * Whether the hash ring changed depends on whether "removedNodeIds" or "addedNodeIds" is empty. * * We use different ways to build the hash ring for realtime jobs and historical analysis: - * 1. For historical analysis,if node removed, we remove it immediately from adVersionCircles - * to avoid new AD task routes to it. If new node added, we add it immediately to adVersionCircles - * to make load more balanced and speed up AD task running. + * 1. For historical analysis, if a node is removed, we remove it immediately from version circles + * to avoid routing new tasks to it. If a new node is added, we add it immediately to version circles + * to make load more balanced and speed up task running. * 2. For realtime jobs, we don't record which node runs a detector's model partition. We just * use the hash ring to get the owning node. If we rebuild the hash ring frequently, a realtime job may get * a different owning node and need to restore the model on the new owning node. If that happens a lot, * it adds heavy model-restore overhead, so we only rebuild the realtime hash ring after a cooldown period. * If a node is removed during the cooldown period, the realtime job may still pick it as the owning node * and still send RCF request to it. If a new node is added during the cooldown period, the realtime job won't * choose it as the model partition owning node, thus we may have skewed load on data nodes. * - * [Important!]: When you call this function, make sure you TRY ACQUIRE adVersionCircleInProgress first. + * [Important!]: When you call this function, make sure you TRY ACQUIRE buildHashRingSemaphore first.
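 * <p>A minimal sketch of that guard (illustrative; the release-on-completion
 * wiring here uses ActionListener.runAfter, mirroring ClusterEventListener above,
 * not the exact code of this class):
 * <pre>
 * if (!buildHashRingSemaphore.tryAcquire()) {
 *     listener.onResponse(false);  // a rebuild is already running; don't block
 *     return;
 * }
 * buildCircles(removed, added, ActionListener.runAfter(listener, buildHashRingSemaphore::release));
 * </pre>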
* Check {@link HashRing#buildCircles(ActionListener)} and * {@link HashRing#buildCircles(DiscoveryNodes.Delta, ActionListener)} * @@ -222,10 +222,10 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, if (removedNodeIds != null && removedNodeIds.size() > 0) { LOG.info("Node removed: {}", Arrays.toString(removedNodeIds.toArray(new String[0]))); for (String nodeId : removedNodeIds) { - ADNodeInfo nodeInfo = nodeAdVersions.remove(nodeId); + TimeSeriesNodeInfo nodeInfo = nodeVersions.remove(nodeId); if (nodeInfo != null && nodeInfo.isEligibleDataNode()) { - removeNodeFromCircles(nodeId, nodeInfo.getAdVersion()); - LOG.info("Remove data node from AD version hash ring: {}", nodeId); + removeNodeFromCircles(nodeId, nodeInfo.getVersion()); + LOG.info("Remove data node from version hash ring: {}", nodeId); } } } @@ -234,12 +234,12 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, if (addedNodeIds != null) { allAddedNodes.addAll(addedNodeIds); } - if (!nodeAdVersions.containsKey(localNode.getId())) { + if (!nodeVersions.containsKey(localNode.getId())) { allAddedNodes.add(localNode.getId()); } if (allAddedNodes.size() == 0) { actionListener.onResponse(true); - // rebuild AD version hash ring with cooldown. + // rebuild version hash ring with cooldown. rebuildCirclesForRealtimeAD(); buildHashRingSemaphore.release(); return; @@ -264,15 +264,16 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } TreeMap circle = null; for (PluginInfo pluginInfo : plugins.getPluginInfos()) { + // if (AD_PLUGIN_NAME.equals(pluginInfo.getName()) || AD_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { if (CommonName.TIME_SERIES_PLUGIN_NAME.equals(pluginInfo.getName()) || CommonName.TIME_SERIES_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { - Version version = ADVersionUtil.fromString(pluginInfo.getVersion()); + Version version = VersionUtil.fromString(pluginInfo.getVersion()); boolean eligibleNode = nodeFilter.isEligibleNode(curNode); if (eligibleNode) { circle = circles.computeIfAbsent(version, key -> new TreeMap<>()); - LOG.info("Add data node to AD version hash ring: {}", curNode.getId()); + LOG.info("Add data node to version hash ring: {}", curNode.getId()); } - nodeAdVersions.put(curNode.getId(), new ADNodeInfo(version, eligibleNode)); + nodeVersions.put(curNode.getId(), new TimeSeriesNodeInfo(version, eligibleNode)); break; } } @@ -283,15 +284,15 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } } } - LOG.info("All nodes with known AD version: {}", nodeAdVersions); + LOG.info("All nodes with known version: {}", nodeVersions); - // rebuild AD version hash ring with cooldown after all new node added. + // rebuild version hash ring with cooldown after all new node added. rebuildCirclesForRealtimeAD(); if (!dataMigrator.isMigrated() && circles.size() > 0) { - // Find owning node with highest AD version to make sure the data migration logic be compatible to - // latest AD version when upgrade. - Optional owningNode = getOwningNodeWithHighestAdVersion(DEFAULT_HASH_RING_MODEL_ID); + // Find owning node with highest version to make sure the data migration logic be compatible to + // latest version when upgrade. 
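// For reference, that lookup (getOwningNodeWithHighestVersion, defined later in
// this class) is plain consistent hashing over a TreeMap keyed by Murmur3 hashes:
//   TreeMap<Integer, DiscoveryNode> circle = circles.lastEntry().getValue();  // highest version
//   Map.Entry<Integer, DiscoveryNode> entry = circle.higherEntry(Murmur3HashFunction.hash(modelId));
//   DiscoveryNode owner = (entry != null ? entry : circle.firstEntry()).getValue();  // wrap around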
+ Optional owningNode = getOwningNodeWithHighestVersion(DEFAULT_HASH_RING_MODEL_ID); String localNodeId = localNode.getId(); if (owningNode.isPresent() && localNodeId.equals(owningNode.get().getId())) { dataMigrator.migrateData(); @@ -305,18 +306,18 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, }, e -> { buildHashRingSemaphore.release(); actionListener.onFailure(e); - LOG.error("Fail to get node info to build AD version hash ring", e); + LOG.error("Fail to get node info to build hash ring", e); })); } catch (Exception e) { - LOG.error("Failed to build AD version circles", e); + LOG.error("Failed to build circles", e); buildHashRingSemaphore.release(); actionListener.onFailure(e); } } - private void removeNodeFromCircles(String nodeId, Version adVersion) { - if (adVersion != null) { - TreeMap circle = this.circles.get(adVersion); + private void removeNodeFromCircles(String nodeId, Version version) { + if (version != null) { + TreeMap circle = this.circles.get(version); List deleted = new ArrayList<>(); for (Map.Entry entry : circle.entrySet()) { if (entry.getValue().getId().equals(nodeId)) { @@ -324,7 +325,7 @@ private void removeNodeFromCircles(String nodeId, Version adVersion) { } } if (deleted.size() == circle.size()) { - circles.remove(adVersion); + circles.remove(version); } else { for (Integer key : deleted) { circle.remove(key); @@ -336,7 +337,7 @@ private void removeNodeFromCircles(String nodeId, Version adVersion) { private void rebuildCirclesForRealtimeAD() { // Check if it's eligible to rebuild hash ring with cooldown if (eligibleToRebuildCirclesForRealtimeAD()) { - LOG.info("Rebuild AD hash ring for realtime AD with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); + LOG.info("Rebuild hash ring for realtime with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); int size = nodeChangeEvents.size(); TreeMap> newCircles = new TreeMap<>(); for (Map.Entry> entry : circles.entrySet()) { @@ -344,17 +345,17 @@ private void rebuildCirclesForRealtimeAD() { } circlesForRealtimeAD = newCircles; lastUpdateForRealtimeAD = clock.millis(); - LOG.info("Build AD version hash ring successfully"); + LOG.info("Build version hash ring successfully"); String localNodeId = clusterService.localNode().getId(); Set modelIds = modelManager.getAllModelIds(); for (String modelId : modelIds) { - Optional node = getOwningNodeWithSameLocalAdVersionForRealtimeAD(modelId); + Optional node = getOwningNodeWithSameLocalVersionForRealtime(modelId); if (node.isPresent() && !node.get().getId().equals(localNodeId)) { LOG.info(REMOVE_MODEL_MSG + " {}", modelId); modelManager .stopModel( // stopModel will clear model cache - SingleStreamModelIdMapper.getDetectorIdForModelId(modelId), + SingleStreamModelIdMapper.getConfigIdForModelId(modelId), modelId, ActionListener .wrap( @@ -366,7 +367,7 @@ private void rebuildCirclesForRealtimeAD() { } // It's possible that multiple threads add new event to nodeChangeEvents, // but this is the only place to consume/poll the event and there is only - // one thread poll it as we are using adVersionCircleInProgress semaphore(1) + // one thread poll it as we are using buildHashRingSemaphore // to control only 1 thread build hash ring. 
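// For reference, the cooldown gate checked at the top of this method
// (eligibleToRebuildCirclesForRealtimeAD) reduces to roughly the following;
// a sketch, not the literal method body:
//   if (nodeChangeEvents.isEmpty() && !circlesForRealtimeAD.isEmpty()) return false;
//   return circlesForRealtimeAD.isEmpty()
//       || clock.millis() - lastUpdateForRealtimeAD >= coolDownPeriodForRealtimeAD.getMillis();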
while (size-- > 0) { Boolean poll = nodeChangeEvents.poll(); @@ -397,7 +398,7 @@ private void rebuildCirclesForRealtimeAD() { * * @return true if it's eligible to rebuild hash ring */ - protected boolean eligibleToRebuildCirclesForRealtimeAD() { + public boolean eligibleToRebuildCirclesForRealtimeAD() { // Check if there is any node change event if (nodeChangeEvents.isEmpty() && !circlesForRealtimeAD.isEmpty()) { return false; @@ -412,71 +413,71 @@ protected boolean eligibleToRebuildCirclesForRealtimeAD() { } /** - * Get owning node with highest AD version circle. + * Get owning node with highest version circle. * @param modelId model id * @return owning node */ - public Optional getOwningNodeWithHighestAdVersion(String modelId) { + public Optional getOwningNodeWithHighestVersion(String modelId) { int modelHash = Murmur3HashFunction.hash(modelId); Map.Entry> versionTreeMapEntry = circles.lastEntry(); if (versionTreeMapEntry == null) { return Optional.empty(); } - TreeMap adVersionCircle = versionTreeMapEntry.getValue(); - Map.Entry entry = adVersionCircle.higherEntry(modelHash); - return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + TreeMap versionCircle = versionTreeMapEntry.getValue(); + Map.Entry entry = versionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(versionCircle.firstEntry())).map(x -> x.getValue()); } /** - * Get owning node with same AD version of local node. + * Get owning node with same version of local node. * @param modelId model id * @param function consumer function * @param listener action listener * @param listener response type */ - public void buildAndGetOwningNodeWithSameLocalAdVersion( + public void buildAndGetOwningNodeWithSameLocalVersion( String modelId, Consumer> function, ActionListener listener ) { buildCircles(ActionListener.wrap(r -> { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, false); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameVersionDirectly(modelId, version, false); function.accept(owningNode); }, e -> listener.onFailure(e))); } - public Optional getOwningNodeWithSameLocalAdVersionForRealtimeAD(String modelId) { + public Optional getOwningNodeWithSameLocalVersionForRealtime(String modelId) { try { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, true); + Version version = nodeVersions.containsKey(localNode.getId()) ? 
getVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameVersionDirectly(modelId, version, true); // rebuild hash ring - buildCirclesForRealtimeAD(); + buildCirclesForRealtime(); return owningNode; } catch (Exception e) { - LOG.error("Failed to get owning node with same local AD version", e); + LOG.error("Failed to get owning node with same local time series version", e); return Optional.empty(); } } - private Optional getOwningNodeWithSameAdVersionDirectly(String modelId, Version adVersion, boolean forRealtime) { + private Optional getOwningNodeWithSameVersionDirectly(String modelId, Version version, boolean forRealtime) { int modelHash = Murmur3HashFunction.hash(modelId); - TreeMap adVersionCircle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); - if (adVersionCircle != null) { - Map.Entry entry = adVersionCircle.higherEntry(modelHash); - return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + TreeMap versionCircle = forRealtime ? circlesForRealtimeAD.get(version) : circles.get(version); + if (versionCircle != null) { + Map.Entry entry = versionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(versionCircle.firstEntry())).map(x -> x.getValue()); } return Optional.empty(); } - public void getNodesWithSameLocalAdVersion(Consumer function, ActionListener listener) { + public void getNodesWithSameLocalVersion(Consumer function, ActionListener listener) { buildCircles(ActionListener.wrap(updated -> { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Set nodes = getNodesWithSameAdVersion(adVersion, false); - if (!nodeAdVersions.containsKey(localNode.getId())) { + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameVersion(version, false); + if (!nodeVersions.containsKey(localNode.getId())) { nodes.add(localNode); } // Make sure listener return in function @@ -484,17 +485,17 @@ public void getNodesWithSameLocalAdVersion(Consumer functio }, e -> listener.onFailure(e))); } - public DiscoveryNode[] getNodesWithSameLocalAdVersion() { + public DiscoveryNode[] getNodesWithSameLocalVersion() { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Set nodes = getNodesWithSameAdVersion(adVersion, false); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameVersion(version, false); // rebuild hash ring - buildCirclesForRealtimeAD(); + buildCirclesForRealtime(); return nodes.toArray(new DiscoveryNode[0]); } - protected Set getNodesWithSameAdVersion(Version adVersion, boolean forRealtime) { - TreeMap circle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); + public Set getNodesWithSameVersion(Version version, boolean forRealtime) { + TreeMap circle = forRealtime ? circlesForRealtimeAD.get(version) : circles.get(version); Set nodeIds = new HashSet<>(); Set nodes = new HashSet<>(); if (circle == null) { @@ -511,13 +512,13 @@ protected Set getNodesWithSameAdVersion(Version adVersion, boolea } /** - * Get AD version. + * Get time series version. 
* @param nodeId node id - * @return AD version + * @return version */ - public Version getAdVersion(String nodeId) { - ADNodeInfo adNodeInfo = nodeAdVersions.get(nodeId); - return adNodeInfo == null ? null : adNodeInfo.getAdVersion(); + public Version getVersion(String nodeId) { + TimeSeriesNodeInfo nodeInfo = nodeVersions.get(nodeId); + return nodeInfo == null ? null : nodeInfo.getVersion(); } /** @@ -561,17 +562,17 @@ private String getIpAddress(TransportAddress address) { } /** - * Get all eligible data nodes whose AD versions are known in AD version based hash ring. + * Get all eligible data nodes whose time series versions are known in hash ring. * @param function consumer function * @param listener action listener * @param action listener response type */ - public void getAllEligibleDataNodesWithKnownAdVersion(Consumer function, ActionListener listener) { + public void getAllEligibleDataNodesWithKnownVersion(Consumer function, ActionListener listener) { buildCircles(ActionListener.wrap(r -> { DiscoveryNode[] eligibleDataNodes = nodeFilter.getEligibleDataNodes(); List allNodes = new ArrayList<>(); for (DiscoveryNode node : eligibleDataNodes) { - if (nodeAdVersions.containsKey(node.getId())) { + if (nodeVersions.containsKey(node.getId())) { allNodes.add(node); } } diff --git a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java b/src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java similarity index 84% rename from src/main/java/org/opensearch/ad/cluster/HourlyCron.java rename to src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java index 687aace69..4381ade8d 100644 --- a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java +++ b/src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java @@ -9,23 +9,23 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.FailedNodeException; import org.opensearch.ad.transport.CronAction; -import org.opensearch.ad.transport.CronRequest; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.transport.CronRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class HourlyCron implements Runnable { private static final Logger LOG = LogManager.getLogger(HourlyCron.class); - static final String SUCCEEDS_LOG_MSG = "Hourly maintenance succeeds"; - static final String NODE_EXCEPTION_LOG_MSG = "Hourly maintenance of node has exception"; - static final String EXCEPTION_LOG_MSG = "Hourly maintenance has exception."; + public static final String SUCCEEDS_LOG_MSG = "Hourly maintenance succeeds"; + public static final String NODE_EXCEPTION_LOG_MSG = "Hourly maintenance of node has exception"; + public static final String EXCEPTION_LOG_MSG = "Hourly maintenance has exception."; private DiscoveryNodeFilterer nodeFilter; private Client client; diff --git a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java b/src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java similarity index 70% rename from src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java rename to src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java index e438623d5..f67d663ae 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java +++ b/src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java @@ -9,25 +9,25 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.opensearch.Version; /** - * This class records AD version of nodes and whether node is eligible data node to run AD. + * This class records time series plugin version of nodes and whether node is eligible data node to run time series analysis. */ -public class ADNodeInfo { - // AD plugin version +public class TimeSeriesNodeInfo { + // time series plugin version private Version adVersion; // Is node eligible to run AD. private boolean isEligibleDataNode; - public ADNodeInfo(Version version, boolean isEligibleDataNode) { + public TimeSeriesNodeInfo(Version version, boolean isEligibleDataNode) { this.adVersion = version; this.isEligibleDataNode = isEligibleDataNode; } - public Version getAdVersion() { + public Version getVersion() { return adVersion; } diff --git a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java b/src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java similarity index 95% rename from src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java rename to src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java index 7e880de66..8d506732d 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java +++ b/src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java @@ -9,12 +9,12 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.opensearch.Version; import org.opensearch.timeseries.constant.CommonName; -public class ADVersionUtil { +public class VersionUtil { public static final int VERSION_SEGMENTS = 3; diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/BaseModelCheckpointIndexRetention.java similarity index 86% rename from src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java rename to src/main/java/org/opensearch/timeseries/cluster/diskcleanup/BaseModelCheckpointIndexRetention.java index 28fc05e37..6885d611e 100644 --- a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java +++ b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/BaseModelCheckpointIndexRetention.java @@ -9,14 +9,13 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster.diskcleanup; +package org.opensearch.timeseries.cluster.diskcleanup; import java.time.Clock; import java.time.Duration; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; @@ -34,8 +33,8 @@ * We will keep this logic, and add a new cleanup way based on shard size. *

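 * <p>For illustration, the TTL pass in run() below reduces to a delete-by-query
 * with an epoch_millis range filter over the checkpoint timestamp:
 * <pre>
 * indexCleanup.deleteDocsByQuery(
 *     checkpointIndexName,
 *     QueryBuilders.boolQuery()
 *         .filter(QueryBuilders.rangeQuery(CommonName.TIMESTAMP)
 *             .lte(clock.millis() - defaultCheckpointTtl.toMillis())
 *             .format(CommonName.EPOCH_MILLIS_FORMAT)),
 *     listener);
 * </pre>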
*/ -public class ModelCheckpointIndexRetention implements Runnable { - private static final Logger LOG = LogManager.getLogger(ModelCheckpointIndexRetention.class); +public class BaseModelCheckpointIndexRetention implements Runnable { + private static final Logger LOG = LogManager.getLogger(BaseModelCheckpointIndexRetention.class); // The recommended max shard size is 50G, we don't want our index to exceed this number private static final long MAX_SHARD_SIZE_IN_BYTE = 50 * 1024 * 1024 * 1024L; @@ -46,25 +45,32 @@ public class ModelCheckpointIndexRetention implements Runnable { private final Duration defaultCheckpointTtl; private final Clock clock; private final IndexCleanup indexCleanup; + private final String checkpointIndexName; - public ModelCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + public BaseModelCheckpointIndexRetention( + Duration defaultCheckpointTtl, + Clock clock, + IndexCleanup indexCleanup, + String checkpointIndexName + ) { this.defaultCheckpointTtl = defaultCheckpointTtl; this.clock = clock; this.indexCleanup = indexCleanup; + this.checkpointIndexName = checkpointIndexName; } @Override public void run() { indexCleanup .deleteDocsByQuery( - ADCommonName.CHECKPOINT_INDEX_NAME, + checkpointIndexName, QueryBuilders .boolQuery() .filter( QueryBuilders .rangeQuery(CommonName.TIMESTAMP) .lte(clock.millis() - defaultCheckpointTtl.toMillis()) - .format(ADCommonName.EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) ), ActionListener.wrap(response -> { cleanupBasedOnShardSize(defaultCheckpointTtl.minusDays(1)); @@ -79,7 +85,7 @@ public void run() { private void cleanupBasedOnShardSize(Duration cleanUpTtl) { indexCleanup .deleteDocsBasedOnShardSize( - ADCommonName.CHECKPOINT_INDEX_NAME, + checkpointIndexName, MAX_SHARD_SIZE_IN_BYTE, QueryBuilders .boolQuery() .filter( QueryBuilders .rangeQuery(CommonName.TIMESTAMP) .lte(clock.millis() - cleanUpTtl.toMillis()) - .format(ADCommonName.EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) ), ActionListener.wrap(cleanupNeeded -> { if (cleanupNeeded) { diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/IndexCleanup.java similarity index 98% rename from src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java rename to src/main/java/org/opensearch/timeseries/cluster/diskcleanup/IndexCleanup.java index bd37127cb..899f41a73 100644 --- a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java +++ b/src/main/java/org/opensearch/timeseries/cluster/diskcleanup/IndexCleanup.java @@ -9,7 +9,7 @@ * GitHub history for details.
*/ -package org.opensearch.ad.cluster.diskcleanup; +package org.opensearch.timeseries.cluster.diskcleanup; import java.util.Arrays; import java.util.Objects; diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java index ae2064add..ae086c109 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java @@ -13,6 +13,8 @@ import java.util.Locale; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + public class CommonMessages { // ====================================== // Validation message @@ -49,6 +51,27 @@ public static String getTooManyCategoricalFieldErr(int limit) { public static final String UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG = "Feature has an unknown exception caught while executing the feature query: "; public static String DUPLICATE_FEATURE_AGGREGATION_NAMES = "Config has duplicate feature aggregation query names: "; + public static String TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA = + "There isn't enough historical data found with current timefield selected."; + public static String CATEGORY_FIELD_TOO_SPARSE = + "Data is most likely too sparse with the given category fields. Consider revising category field/s or ingesting more data "; + public static String WINDOW_DELAY_REC = + "Latest seen data point is at least %d minutes ago, consider changing window delay to at least %d minutes."; + public static String INTERVAL_REC = "The selected interval might collect sparse data. Consider changing interval length to: "; + public static String RAW_DATA_TOO_SPARSE = + "Source index data is potentially too sparse for model training. Consider changing interval length or ingesting more data"; + public static String MODEL_VALIDATION_FAILED_UNEXPECTEDLY = "Model validation experienced issues completing."; + public static String FILTER_QUERY_TOO_SPARSE = "Data is too sparse after data filter is applied. Consider changing the data filter"; + public static String CATEGORY_FIELD_NO_DATA = + "No entity was found with the given categorical fields. Consider revising category field/s or ingesting more data"; + public static String FEATURE_QUERY_TOO_SPARSE = + "Data is most likely too sparse when given feature queries are applied. Consider revising feature queries."; + public static String TIMEOUT_ON_INTERVAL_REC = "Timed out getting interval recommendation"; + public static final String NOT_EXISTENT_VALIDATION_TYPE = "The given validation type doesn't exist"; + public static final String NOT_EXISTENT_SUGGEST_TYPE = "The given suggest type doesn't exist"; + public static final String DESCRIPTION_LENGTH_TOO_LONG = "Description length is too long. 
Max length is " + + TimeSeriesSettings.MAX_DESCRIPTION_LENGTH + + " characters."; // ====================================== // Index message @@ -61,12 +84,14 @@ public static String getTooManyCategoricalFieldErr(int limit) { // Resource constraints // ====================================== public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = - "The total OpenSearch memory usage exceeds our threshold, opening the AD memory circuit."; + "The total OpenSearch memory usage exceeds our threshold, opening the memory circuit."; // ====================================== // Transport // ====================================== public static final String INVALID_TIMESTAMP_ERR_MSG = "timestamp is invalid"; + public static String FAIL_TO_DELETE_CONFIG = "Fail to delete config"; + public static String FAIL_TO_GET_CONFIG_INFO = "Fail to get config info"; // ====================================== // transport/restful client @@ -91,4 +116,34 @@ public static String getTooManyCategoricalFieldErr(int limit) { public static String NO_PERMISSION_TO_ACCESS_CONFIG = "User does not have permissions to access config: "; public static String FAIL_TO_GET_USER_INFO = "Unable to get user information from config "; + // ====================================== + // transport + // ====================================== + public static final String CONFIG_ID_MISSING_MSG = "config ID is missing"; + public static final String MODEL_ID_MISSING_MSG = "model ID is missing"; + + // ====================================== + // task + // ====================================== + public static String CAN_NOT_FIND_LATEST_TASK = "can't find latest task"; + + // ====================================== + // Job + // ====================================== + public static String CONFIG_IS_RUNNING = "Config is already running"; + public static String FAIL_TO_SEARCH = "Fail to search"; + + // ====================================== + // Profile API + // ====================================== + public static String EMPTY_PROFILES_COLLECT = "profiles to collect are missing or invalid"; + public static String FAIL_TO_PARSE_CONFIG_MSG = "Fail to parse config with id: "; + public static String FAIL_FETCH_ERR_MSG = "Fail to fetch profile for "; + public static String FAIL_TO_GET_PROFILE_MSG = "Fail to get profile for config "; + public static String FAIL_TO_GET_TOTAL_ENTITIES = "Failed to get total entities for config "; + + // ====================================== + // Stats API + // ====================================== + public static String FAIL_TO_GET_STATS = "Fail to get stats"; } diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonName.java b/src/main/java/org/opensearch/timeseries/constant/CommonName.java index 0b997ea5d..88c8de185 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonName.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonName.java @@ -11,6 +11,8 @@ package org.opensearch.timeseries.constant; +import org.opensearch.timeseries.stats.StatNames; + public class CommonName { // ====================================== @@ -62,7 +64,6 @@ public class CommonName { public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; public static final String ERROR_FIELD = "error"; - public static final String ENTITY_FIELD = "entity"; public static final String USER_FIELD = "user"; public static final String CONFIDENCE_FIELD = "confidence"; public static final String DATA_QUALITY_FIELD = 
"data_quality"; @@ -70,6 +71,9 @@ public class CommonName { public static final String MODEL_ID_FIELD = "model_id"; public static final String TIMESTAMP = "timestamp"; public static final String FIELD_MODEL = "model"; + public static final String ANALYSIS_TYPE_FIELD = "analysis_type"; + public static final String ANSWER_FIELD = "answer"; + public static final String RUN_ONCE_FIELD = "run_once"; // entity sample in checkpoint. // kept for bwc purpose @@ -105,6 +109,7 @@ public class CommonName { public static final String CONFIG_ID_KEY = "config_id"; public static final String MODEL_ID_KEY = "model_id"; public static final String TASK_ID_FIELD = "task_id"; + public static final String TASK = "task"; public static final String ENTITY_ID_FIELD = "entity_id"; // ====================================== @@ -113,4 +118,36 @@ public class CommonName { public static final String TIME_SERIES_PLUGIN_NAME = "opensearch-time-series-analytics"; public static final String TIME_SERIES_PLUGIN_NAME_FOR_TEST = "org.opensearch.timeseries.TimeSeriesAnalyticsPlugin"; public static final String TIME_SERIES_PLUGIN_VERSION_FOR_TEST = "NA"; + + // ====================================== + // Profile name + // ====================================== + public static final String CATEGORICAL_FIELD = "category_field"; + public static final String STATE = "state"; + public static final String ERROR = "error"; + public static final String COORDINATING_NODE = "coordinating_node"; + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; + public static final String MODELS = "models"; + public static final String MODEL = "model"; + public static final String INIT_PROGRESS = "init_progress"; + public static final String TOTAL_ENTITIES = "total_entities"; + public static final String ACTIVE_ENTITIES = "active_entities"; + public static final String ENTITY_INFO = "entity_info"; + public static final String TOTAL_UPDATES = "total_updates"; + public static final String MODEL_COUNT = StatNames.MODEL_COUNT.getName(); + + // ====================================== + // Ultrawarm node attributes + // ====================================== + // hot node + public static String HOT_BOX_TYPE = "hot"; + // warm node + public static String WARM_BOX_TYPE = "warm"; + // box type + public static final String BOX_TYPE_KEY = "box_type"; + // ====================================== + // Format name + // ====================================== + public static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; } diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java b/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java index 4e885421c..801489c7b 100644 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java @@ -5,6 +5,8 @@ package org.opensearch.timeseries.dataprocessor; +import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; + /* * An object for imputing feature vectors. * @@ -24,18 +26,20 @@ public abstract class Imputer { * `numFeatures`. * * - * @param samples A `numFeatures x numSamples` list of feature vectors. + * @param samples A `numSamples x numFeatures` list of feature vectors. * @param numImputed The desired number of imputed vectors. - * @return A `numFeatures x numImputed` list of feature vectors. + * @return A `numSamples x numFeatures` list of feature vectors. 
*/ public double[][] impute(double[][] samples, int numImputed) { - int numFeatures = samples.length; - double[][] interpolants = new double[numFeatures][numImputed]; - + // convert to a `numFeatures x numSamples` list of feature vectors + double[][] transposed = transpose(samples); + int numFeatures = transposed.length; + double[][] imputants = new double[numFeatures][numImputed]; for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) { - interpolants[featureIndex] = singleFeatureImpute(samples[featureIndex], numImputed); + imputants[featureIndex] = singleFeatureImpute(transposed[featureIndex], numImputed); } - return interpolants; + // transpose back to a `numSamples x numFeatures` list of feature vectors + return transpose(imputants); } /** @@ -45,4 +49,8 @@ public double[][] impute(double[][] samples, int numImputed) { * @return input array with missing values imputed */ protected abstract double[] singleFeatureImpute(double[] samples, int numImputed); + + private double[][] transpose(double[][] matrix) { + return createRealMatrix(matrix).transpose().getData(); + } } diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java b/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java index e91c90814..0776dcb8a 100644 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java @@ -18,25 +18,25 @@ public class PreviousValueImputer extends Imputer { @Override - protected double[] singleFeatureImpute(double[] samples, int numInterpolants) { + protected double[] singleFeatureImpute(double[] samples, int numImputants) { int numSamples = samples.length; - double[] interpolants = new double[numSamples]; + double[] imputants = new double[numSamples]; if (numSamples > 0) { - System.arraycopy(samples, 0, interpolants, 0, samples.length); + System.arraycopy(samples, 0, imputants, 0, samples.length); if (numSamples > 1) { double lastKnownValue = Double.NaN; for (int interpolantIndex = 0; interpolantIndex < numSamples; interpolantIndex++) { - if (Double.isNaN(interpolants[interpolantIndex])) { + if (Double.isNaN(imputants[interpolantIndex])) { if (!Double.isNaN(lastKnownValue)) { - interpolants[interpolantIndex] = lastKnownValue; + imputants[interpolantIndex] = lastKnownValue; } } else { - lastKnownValue = interpolants[interpolantIndex]; + lastKnownValue = imputants[interpolantIndex]; } } } } - return interpolants; + return imputants; } } diff --git a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java similarity index 99% rename from src/main/java/org/opensearch/ad/feature/AbstractRetriever.java rename to src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java index 886dbcbc4..5f2609ed5 100644 --- a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Arrays; import java.util.Iterator; diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java similarity index 93% rename from src/main/java/org/opensearch/ad/feature/CompositeRetriever.java rename to src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java index 3c9a1632a..f4cae0c0e 100644 --- a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.io.IOException; import java.time.Clock; @@ -27,7 +27,6 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -46,6 +45,7 @@ import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.util.ParseUtils; @@ -66,7 +66,7 @@ public class CompositeRetriever extends AbstractRetriever { private final long dataStartEpoch; private final long dataEndEpoch; - private final AnomalyDetector anomalyDetector; + private final Config config; private final NamedXContentRegistry xContent; private final Client client; private final SecurityClientUtil clientUtil; @@ -78,11 +78,12 @@ public class CompositeRetriever extends AbstractRetriever { private Clock clock; private IndexNameExpressionResolver indexNameExpressionResolver; private ClusterService clusterService; + private AnalysisType context; public CompositeRetriever( long dataStartEpoch, long dataEndEpoch, - AnomalyDetector anomalyDetector, + Config config, NamedXContentRegistry xContent, Client client, SecurityClientUtil clientUtil, @@ -92,11 +93,12 @@ public CompositeRetriever( int maxEntitiesPerInterval, int pageSize, IndexNameExpressionResolver indexNameExpressionResolver, - ClusterService clusterService + ClusterService clusterService, + AnalysisType context ) { this.dataStartEpoch = dataStartEpoch; this.dataEndEpoch = dataEndEpoch; - this.anomalyDetector = anomalyDetector; + this.config = config; this.xContent = xContent; this.client = client; this.clientUtil = clientUtil; @@ -107,13 +109,14 @@ public CompositeRetriever( this.clock = clock; this.indexNameExpressionResolver = indexNameExpressionResolver; this.clusterService = clusterService; + this.context = context; } // a constructor that provide default value of clock public CompositeRetriever( long dataStartEpoch, long dataEndEpoch, - AnomalyDetector anomalyDetector, + Config anomalyDetector, NamedXContentRegistry xContent, Client client, SecurityClientUtil clientUtil, @@ -122,7 +125,8 @@ public CompositeRetriever( int maxEntitiesPerInterval, int pageSize, IndexNameExpressionResolver indexNameExpressionResolver, - ClusterService clusterService + ClusterService clusterService, + AnalysisType context ) { this( 
dataStartEpoch, @@ -137,7 +141,8 @@ public CompositeRetriever( maxEntitiesPerInterval, pageSize, indexNameExpressionResolver, - clusterService + clusterService, + context ); } @@ -147,21 +152,21 @@ public CompositeRetriever( * detector definition */ public PageIterator iterator() throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .gte(dataStartEpoch) .lt(dataEndEpoch) .format("epoch_millis"); - BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(anomalyDetector.getFilterQuery()).filter(rangeQuery); + BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(config.getFilterQuery()).filter(rangeQuery); // multiple categorical fields are supported CompositeAggregationBuilder composite = AggregationBuilders .composite( AGG_NAME_COMP, - anomalyDetector.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) ) .size(pageSize); - for (Feature feature : anomalyDetector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { AggregatorFactories.Builder internalAgg = ParseUtils .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId()); composite.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); @@ -201,7 +206,7 @@ public void next(ActionListener listener) { // inject user role while searching. - SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0]), source); final ActionListener searchResponseListener = new ActionListener() { @Override public void onResponse(SearchResponse response) { @@ -219,9 +224,9 @@ public void onFailure(Exception e) { .asyncRequestWithInjectedSecurity( searchRequest, client::search, - anomalyDetector.getId(), + config.getId(), client, - AnalysisType.AD, + context, searchResponseListener ); } @@ -291,7 +296,7 @@ private Page analyzePage(SearchResponse response) { } */ for (Bucket bucket : composite.getBuckets()) { - Optional featureValues = parseBucket(bucket, anomalyDetector.getEnabledFeatureIds()); + Optional featureValues = parseBucket(bucket, config.getEnabledFeatureIds()); // bucket.getKey() returns a map of categorical field like "host" and its value like "server_1" if (featureValues.isPresent() && bucket.getKey() != null) { results.put(Entity.createEntityByReordering(bucket.getKey()), featureValues.get()); @@ -335,7 +340,7 @@ Optional getComposite(SearchResponse response) { // such index // [blah]","index":"blah","resource.id":"blah","resource.type":"index_or_alias","index_uuid":"_na_"},"status":404}% if (response == null || response.getAggregations() == null) { - List sourceIndices = anomalyDetector.getIndices(); + List sourceIndices = config.getIndices(); String[] concreteIndices = indexNameExpressionResolver .concreteIndexNames(clusterService.state(), IndicesOptions.lenientExpandOpen(), sourceIndices.toArray(new String[0])); if (concreteIndices.length == 0) { diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java similarity index 73% rename from src/main/java/org/opensearch/ad/feature/FeatureManager.java rename to 
src/main/java/org/opensearch/timeseries/feature/FeatureManager.java index 469f8707e..5fe777a17 100644 --- a/src/main/java/org/opensearch/ad/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java @@ -9,10 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import static java.util.Arrays.copyOfRange; -import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; import java.io.IOException; import java.time.Clock; @@ -23,32 +22,41 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collections; import java.util.Deque; +import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.Queue; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import java.util.stream.DoubleStream; import java.util.stream.IntStream; import java.util.stream.LongStream; import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.model.Forecaster; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.CleanState; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.dataprocessor.Imputer; -import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DataUtil; /** * A facade managing feature data operations and buffers. @@ -57,15 +65,13 @@ public class FeatureManager implements CleanState { private static final Logger logger = LogManager.getLogger(FeatureManager.class); - // Each anomaly detector has a queue of data points with timestamps (in epoch milliseconds). + // Each single-stream analysis has a queue of data points with timestamps (in epoch milliseconds). 
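// With generics spelled out, the buffer below has the shape (a sketch, not
// necessarily the exact declaration): Map<String, Deque<Map.Entry<Long, Optional<double[]>>>>,
// i.e. config id -> time-ordered (epochMillis, featureVector) pairs, where a
// missing data point is stored as Optional.empty() so shingle gaps can be imputed.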
private final Map>>> detectorIdsToTimeShingles; private final SearchFeatureDao searchFeatureDao; - private final Imputer imputer; + public final Imputer imputer; private final Clock clock; - private final int maxTrainSamples; - private final int maxSampleStride; private final int trainSampleTimeRangeInHours; private final int minTrainSamples; private final double maxMissingPointsRate; @@ -82,8 +88,6 @@ public class FeatureManager implements CleanState { * @param searchFeatureDao DAO of features from search * @param imputer imputer of samples * @param clock clock for system time - * @param maxTrainSamples max number of samples from search - * @param maxSampleStride max stride between uninterpolated train samples * @param trainSampleTimeRangeInHours time range in hours for collect train samples * @param minTrainSamples min number of train samples * @param maxMissingPointsRate max proportion of shingle with missing points allowed to generate a shingle @@ -98,8 +102,6 @@ public FeatureManager( SearchFeatureDao searchFeatureDao, Imputer imputer, Clock clock, - int maxTrainSamples, - int maxSampleStride, int trainSampleTimeRangeInHours, int minTrainSamples, double maxMissingPointsRate, @@ -113,8 +115,6 @@ public FeatureManager( this.searchFeatureDao = searchFeatureDao; this.imputer = imputer; this.clock = clock; - this.maxTrainSamples = maxTrainSamples; - this.maxSampleStride = maxSampleStride; this.trainSampleTimeRangeInHours = trainSampleTimeRangeInHours; this.minTrainSamples = minTrainSamples; this.maxMissingPointsRate = maxMissingPointsRate; @@ -174,8 +174,35 @@ public void getCurrentFeatures(AnomalyDetector detector, long startTime, long en } } + public void getCurrentFeatures(Forecaster forecaster, long startTime, long endTime, ActionListener listener) { + List> missingRanges = Collections.singletonList(new SimpleImmutableEntry<>(startTime, endTime)); + try { + searchFeatureDao.getFeatureSamplesForPeriods(forecaster, missingRanges, AnalysisType.FORECAST, ActionListener.wrap(points -> { + // we only have one point + if (points.size() == 1) { + Optional point = points.get(0); + listener.onResponse(new SinglePointFeatures(point, Optional.empty())); + } else { + listener.onResponse(new SinglePointFeatures(Optional.empty(), Optional.empty())); + } + }, listener::onFailure)); + } catch (IOException e) { + listener.onFailure(new EndRunException(forecaster.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } + } + + public void getCurrentFeatures(Config config, long startTime, long endTime, ActionListener listener) { + if (config instanceof AnomalyDetector) { + getCurrentFeatures((AnomalyDetector) config, startTime, endTime, listener); + } else if (config instanceof Forecaster) { + getCurrentFeatures((Forecaster) config, startTime, endTime, listener); + } else { + throw new UnsupportedOperationException(String.format(Locale.ROOT, "config type %s is not supported.", config.getClass())); + } + } + private List> getMissingRangesInShingle( - AnomalyDetector detector, + Config detector, Map>> featuresMap, long endTime ) { @@ -207,7 +234,7 @@ private List> getMissingRangesInShingle( * @param listener onResponse is called with unprocessed features and processed features for the current data point. 
*/ private void updateUnprocessedFeatures( - AnomalyDetector detector, + Config detector, Deque>> shingle, Map>> featuresMap, long endTime, @@ -221,17 +248,19 @@ private void updateUnprocessedFeatures( listener.onResponse(getProcessedFeatures(shingle, detector, endTime)); } - private double[][] filterAndFill(Deque>> shingle, long endTime, AnomalyDetector detector) { - int shingleSize = detector.getShingleSize(); + private double[][] filterAndFill(Deque>> shingle, long endTime, Config config) { + double[][] result = null; + + int shingleSize = config.getShingleSize(); Deque>> filteredShingle = shingle .stream() .filter(e -> e.getValue().isPresent()) .collect(Collectors.toCollection(ArrayDeque::new)); - double[][] result = null; + if (filteredShingle.size() >= shingleSize - getMaxMissingPoints(shingleSize)) { // Imputes missing data points with the values of neighboring data points. - long maxMillisecondsDifference = maxNeighborDistance * detector.getIntervalInMilliseconds(); - result = getNearbyPointsForShingle(detector, filteredShingle, endTime, maxMillisecondsDifference) + long maxMillisecondsDifference = maxNeighborDistance * config.getIntervalInMilliseconds(); + result = getNearbyPointsForShingle(config, filteredShingle, endTime, maxMillisecondsDifference) .map(e -> e.getValue().getValue().orElse(null)) .filter(d -> d != null) .toArray(double[][]::new); @@ -240,6 +269,7 @@ private double[][] filterAndFill(Deque>> shingle, result = null; } } + return result; } @@ -254,7 +284,7 @@ private double[][] filterAndFill(Deque>> shingle, * point value. */ private Stream>>> getNearbyPointsForShingle( - AnomalyDetector detector, + Config detector, Deque>> shingle, long endTime, long maxMillisecondsDifference @@ -281,7 +311,7 @@ private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int } /** - * Returns to listener data for cold-start training. + * Returns to listener data for cold-start training. Used in AD single-stream. * * Training data starts with getting samples from (costly) search. * Samples are increased in dimension via shingling. 
@@ -292,27 +322,37 @@ private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int */ public void getColdStartData(AnomalyDetector detector, ActionListener> listener) { ActionListener> latestTimeListener = ActionListener - .wrap(latest -> getColdStartSamples(latest, detector, listener), listener::onFailure); + .wrap(latest -> getColdStartSamples(latest, detector, AnalysisType.AD, listener), listener::onFailure); searchFeatureDao - .getLatestDataTime(detector, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, latestTimeListener, false)); + .getLatestDataTime( + detector, + Optional.empty(), + AnalysisType.AD, + new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, latestTimeListener, false) + ); } - private void getColdStartSamples(Optional latest, AnomalyDetector detector, ActionListener> listener) { - int shingleSize = detector.getShingleSize(); + private void getColdStartSamples( + Optional latest, + Config config, + AnalysisType context, + ActionListener> listener + ) { + int shingleSize = config.getShingleSize(); if (latest.isPresent()) { - List> sampleRanges = getColdStartSampleRanges(detector, latest.get()); + List> sampleRanges = getColdStartSampleRanges(config, latest.get()); try { ActionListener>> getFeaturesListener = ActionListener .wrap(samples -> processColdStartSamples(samples, shingleSize, listener), listener::onFailure); searchFeatureDao .getFeatureSamplesForPeriods( - detector, + config, sampleRanges, - AnalysisType.AD, + context, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, getFeaturesListener, false) ); } catch (IOException e) { - listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + listener.onFailure(new EndRunException(config.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); } } else { listener.onResponse(Optional.empty()); @@ -361,7 +401,7 @@ private Optional fillAndShingle(LinkedList> shingle return result; } - private List> getColdStartSampleRanges(AnomalyDetector detector, long endMillis) { + private List> getColdStartSampleRanges(Config detector, long endMillis) { long interval = detector.getIntervalInMilliseconds(); int numSamples = Math.max((int) (Duration.ofHours(this.trainSampleTimeRangeInHours).toMillis() / interval), this.minTrainSamples); return IntStream @@ -624,18 +664,12 @@ private List> getPreviewRanges(List> ranges, private Entry getPreviewFeatures(double[][] samples, int stride, int shingleSize) { Entry unprocessedAndProcessed = Optional .of(samples) - .map(m -> transpose(m)) .map(m -> imputer.impute(m, stride * (samples.length - 1) + 1)) - .map(m -> transpose(m)) .map(m -> new SimpleImmutableEntry<>(copyOfRange(m, shingleSize - 1, m.length), batchShingle(m, shingleSize))) .get(); return unprocessedAndProcessed; } - public double[][] transpose(double[][] matrix) { - return createRealMatrix(matrix).transpose().getData(); - } - private long truncateToMinute(long epochMillis) { return Instant.ofEpochMilli(epochMillis).truncatedTo(ChronoUnit.MINUTES).toEpochMilli(); } @@ -688,11 +722,7 @@ public SinglePointFeatures getShingledFeatureForHistoricalAnalysis( return getProcessedFeatures(shingle, detector, endTime); } - private SinglePointFeatures getProcessedFeatures( - Deque>> shingle, - AnomalyDetector detector, - long endTime - ) { + private SinglePointFeatures getProcessedFeatures(Deque>> shingle, Config detector, long endTime) { int shingleSize = detector.getShingleSize(); Optional currentPoint = 
shingle.peekLast().getValue(); return new SinglePointFeatures( @@ -705,4 +735,178 @@ private SinglePointFeatures getProcessedFeatures( ); } + /** + * Extract sample array from samples and currentUnprocessed. Impute if necessary. + * Whether to use the provided samples is subject to the time ordering of lastProcessed + * and the provided samples. We throw away unProcessedSamples or currentUnprocessed if + * they are older than lastProcessed. + * + * @param config analysis config. + * @param unProcessedSamples unprocessed Samples stored in memory + * @param lastProcessed Last processed sample. + * @param currentUnprocessed current unprocessed sample. + * @return Continuous samples with possible imputations. + */ + public Pair>, Sample> getContinuousSamples( + Config config, + Deque unProcessedSamples, + Sample lastProcessed, + Sample currentUnprocessed + ) { + Deque samples = new ArrayDeque<>(); + if (lastProcessed != null) { + if (unProcessedSamples != null && !unProcessedSamples.isEmpty()) { + Sample lastElement = unProcessedSamples.getLast(); + if (lastElement != null && lastElement.getDataEndTime().compareTo(lastProcessed.getDataEndTime()) > 0) { + samples.add(lastProcessed); + samples.addAll(unProcessedSamples); + } + } else { + samples.add(lastProcessed); + } + } + + if (currentUnprocessed != null) { + samples.add(currentUnprocessed); + } + + if (samples.isEmpty()) { + return Pair.of(new ArrayList<>(), lastProcessed); + } else { + return removeLastSeenSample(getContinuousSamples(config, samples), lastProcessed); + } + } + + /** + * Remove the first sample since it is used before and included for interpolation's purpose + * @param res input data and sample pair + * @param previousLastSeenSample Last seen sample + * @return input without the first sample we have seen + */ + private Pair>, Sample> removeLastSeenSample( + Pair>, Sample> res, + Sample previousLastSeenSample + ) { + List> values = res.getKey(); + + if (previousLastSeenSample != null && values.size() > 1) { + return Pair.of(values.subList(1, values.size()), res.getValue()); + } else if (values.size() > 0) { + return Pair.of(values, res.getValue()); + } + + return Pair.of(new ArrayList<>(), res.getValue()); + } + + /** + * Extract samples from the input queue. Impute if necessary. + * + * @param config analysis config. + * @param samples Samples accumulated from previous job runs. + * @return Continuous samples with possible imputation and last seen sample. When samples is empty, return + * empty double array and empty last seen sample. + */ + public Pair>, Sample> getContinuousSamples(Config config, Queue samples) { + // To allow for small time variations/delays in running the config. + long maxMillisecondsDifference = config.getIntervalInMilliseconds() / 2; + + TreeMap search = new TreeMap<>(); + // Iterate over the sample queue using an Iterator. 
+ // The Iterator interface provides a way to iterate over the elements of a queue + // in FIFO order + Iterator iterator = samples.iterator(); + long startTimeMillis = 0; + Sample lastElement = null; + while (iterator.hasNext()) { + lastElement = iterator.next(); + long dataEndTimeMillis = lastElement.getDataEndTime().toEpochMilli(); + if (startTimeMillis == 0) { + startTimeMillis = dataEndTimeMillis; + } + double[] valueList = lastElement.getValueList(); + search.put(dataEndTimeMillis, valueList); + } + + if (startTimeMillis == 0 || lastElement == null) { + return Pair.of(new ArrayList<>(), new Sample()); + } + + long endTimeMillis = lastElement.getDataEndTime().toEpochMilli(); + + // There can be small time variations/delays in running the analysis. + // Training data adjusted using end time and interval so that the end time + // of each sample has equal distance. This would help finding the missing + // data range and apply imputation. + // A map of entries, where the key is the computed millisecond timestamp + // associated with an interval in the training data, and the value is an entry + // that contains the actual timestamp of the data point and an optional data + // point value. + List adjustedDataEndTime = getFullTrainingDataEndTimes(endTimeMillis, config.getIntervalInMilliseconds(), startTimeMillis); + Map> adjustedTrainingData = adjustedDataEndTime.stream().map(t -> { + Optional> after = Optional.ofNullable(search.ceilingEntry(t)); + Optional> before = Optional.ofNullable(search.floorEntry(t)); + return after + .filter(a -> Math.abs(t - a.getKey()) <= before.map(b -> Math.abs(t - b.getKey())).orElse(Long.MAX_VALUE)) + .map(Optional::of) + .orElse(before) + // training data not within the max difference range will be filtered out and the corresponding t is Optional.empty and + // later filtered out as well + .filter(e -> Math.abs(t - e.getKey()) < maxMillisecondsDifference) + .map(e -> new SimpleImmutableEntry<>(t, e)); + }) + .filter(Optional::isPresent) + .map(Optional::get) + .collect( + Collectors + .toMap( + Entry::getKey, // Key mapper + Entry::getValue, // Value mapper + (v1, v2) -> v1, // Merge function + TreeMap::new + ) // Map implementation to order the entries by key value + ); + + // convert from long to int as we don't expect a huge number of samples + int totalNumSamples = adjustedDataEndTime.size(); + int numEnabledFeatures = config.getEnabledFeatureIds().size(); + double[][] trainingData = new double[totalNumSamples][numEnabledFeatures]; + + Iterator adjustedEndTimeIterator = adjustedDataEndTime.iterator(); + for (int index = 0; adjustedEndTimeIterator.hasNext(); index++) { + long time = adjustedEndTimeIterator.next(); + Entry entry = adjustedTrainingData.get(time); + if (entry != null) { + // the order of the elements in the Stream is the same as the order of the elements in the List entry.getValue() + trainingData[index] = entry.getValue(); + } else { + // create an array of Double.NaN + trainingData[index] = DoubleStream.generate(() -> Double.NaN).limit(numEnabledFeatures).toArray(); + } + } + + Imputer imputer = config.getImputer(); + double[][] imputedData = DataUtil.ltrim(imputer.impute(trainingData, totalNumSamples)); + List> imputedDataWithDataEndTime = new ArrayList<>(); + adjustedEndTimeIterator = adjustedDataEndTime.iterator(); + for (int index = 0; adjustedEndTimeIterator.hasNext(); index++) { + long dataEndTime = adjustedEndTimeIterator.next(); + imputedDataWithDataEndTime.add(new SimpleImmutableEntry<>(dataEndTime, imputedData[index])); + } + 
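The floor/ceiling lookup above is the core of this method: each interval-spaced end time is matched to the nearest raw sample, preferring the later neighbor on ties, and a match only counts if it lies within half an interval; unmatched grid slots later become rows of Double.NaN for the imputer. A minimal standalone sketch of that matching rule (hypothetical class and method names, plain java.util only):

```java
import java.util.Optional;
import java.util.TreeMap;

public class GridAlignSketch {
    // Snap grid time t to the nearest sample key: take the closest key at or
    // after t if it is at least as close as the closest key at or before t,
    // and accept the winner only when it is within the given tolerance.
    static Optional<Long> nearestWithinTolerance(TreeMap<Long, double[]> samples, long t, long toleranceMillis) {
        Long after = samples.ceilingKey(t);
        Long before = samples.floorKey(t);
        Long best;
        if (after != null && (before == null || after - t <= t - before)) {
            best = after;
        } else {
            best = before;
        }
        return Optional.ofNullable(best).filter(b -> Math.abs(b - t) < toleranceMillis);
    }

    public static void main(String[] args) {
        TreeMap<Long, double[]> samples = new TreeMap<>();
        samples.put(0L, new double[] { 1.0 });
        samples.put(61_000L, new double[] { 2.0 });
        samples.put(185_000L, new double[] { 4.0 });
        // 60s interval, 30s tolerance: grid slot 120s has no sample within
        // 30s, so it would be imputed; the other slots match 0s, 61s and 185s.
        for (long t = 0; t <= 180_000L; t += 60_000L) {
            System.out.println(t + " -> " + nearestWithinTolerance(samples, t, 30_000L));
        }
    }
}
```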
+        return Pair.of(imputedDataWithDataEndTime, lastElement);
+    }
+
+    /**
+     * @param endTime End time of the stream
+     * @param intervalMilli interval between consecutive returned times
+     * @param startTime Start time of the stream
+     * @return a list of epoch timestamps spaced intervalMilli apart, starting at startTime and stopping at or before endTime.
+     */
+    private List<Long> getFullTrainingDataEndTimes(long endTime, long intervalMilli, long startTime) {
+        return LongStream
+            .iterate(startTime, i -> i + intervalMilli)
+            .takeWhile(i -> i <= endTime)
+            .boxed() // Convert LongStream to Stream<Long>
+            .collect(Collectors.toList()); // Collect to List<Long>
+    }
 }
diff --git a/src/main/java/org/opensearch/ad/feature/Features.java b/src/main/java/org/opensearch/timeseries/feature/Features.java
similarity index 98%
rename from src/main/java/org/opensearch/ad/feature/Features.java
rename to src/main/java/org/opensearch/timeseries/feature/Features.java
index de347b78f..13cefc1d8 100644
--- a/src/main/java/org/opensearch/ad/feature/Features.java
+++ b/src/main/java/org/opensearch/timeseries/feature/Features.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
  */
-package org.opensearch.ad.feature;
+package org.opensearch.timeseries.feature;
 
 import java.util.Arrays;
 import java.util.List;
diff --git a/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java
index 1ce44472f..4cdd5548b 100644
--- a/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java
+++ b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java
@@ -38,7 +38,6 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
-import org.opensearch.ad.feature.AbstractRetriever;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
@@ -79,11 +78,10 @@
  * DAO for features from search.
  */
 public class SearchFeatureDao extends AbstractRetriever {
+    private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class);
 
-    protected static final String AGG_NAME_MIN = "min_timefield";
     protected static final String AGG_NAME_TOP = "top_agg";
-
-    private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class);
+    protected static final String AGG_NAME_MIN = "min_timefield";
 
     // Dependencies
     private final Client client;
@@ -166,14 +164,23 @@ public SearchFeatureDao(
     /**
      * Returns to listener the epoch time of the latest data under the detector.
      *
-     * @param detector info about the data
+     * @param config info about the data
      * @param listener onResponse is called with the epoch time of the latest data under the detector
     */
-    public void getLatestDataTime(AnomalyDetector detector, ActionListener<Optional<Long>> listener) {
+    public void getLatestDataTime(Config config, Optional<Entity> entity, AnalysisType context, ActionListener<Optional<Long>> listener) {
+        BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery();
+
+        if (entity.isPresent()) {
+            for (TermQueryBuilder term : entity.get().getTermQueryForCustomerIndex()) {
+                internalFilterQuery.filter(term);
+            }
+        }
+
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
-            .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField()))
+            .query(internalFilterQuery)
+            .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(config.getTimeField()))
             .size(0);
-        SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder);
+        SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(searchSourceBuilder);
         final ActionListener<SearchResponse> searchResponseListener = ActionListener
             .wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure);
         // using the original context in listener as user roles have no permissions for internal operations like fetching a
@@ -182,9 +189,9 @@ public void getLatestDataTime(AnomalyDetector detector, ActionListener
         clientUtil
             .<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
                 searchRequest,
                 client::search,
-                detector.getId(),
+                config.getId(),
                 client,
-                AnalysisType.AD,
+                context,
                 searchResponseListener
             );
     }
@@ -484,7 +491,7 @@ public void getMinDataTime(Config config, Optional<Entity> entity, AnalysisType
         BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery();
 
         if (entity.isPresent()) {
-            for (TermQueryBuilder term : entity.get().getTermQueryBuilders()) {
+            for (TermQueryBuilder term : entity.get().getTermQueryForCustomerIndex()) {
                 internalFilterQuery.filter(term);
             }
         }
diff --git a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java b/src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java
similarity index 97%
rename from src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java
rename to src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java
index cbd7ef78b..9849a67f8 100644
--- a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java
+++ b/src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
*/ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Optional; diff --git a/src/main/java/org/opensearch/timeseries/function/ResponseTransformer.java b/src/main/java/org/opensearch/timeseries/function/ResponseTransformer.java new file mode 100644 index 000000000..93c897718 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/ResponseTransformer.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.function; + +/** + * A functional interface for response transformation + * + * @param input type + * @param output type + */ +@FunctionalInterface +public interface ResponseTransformer { + R transform(T input); +} diff --git a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java index 747f2bfac..717687437 100644 --- a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java +++ b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java @@ -159,7 +159,7 @@ protected IndexManagement( this.threadPool = threadPool; this.clusterService.addLocalNodeClusterManagerListener(this); this.nodeFilter = nodeFilter; - this.settings = Settings.builder().put("index.hidden", true).build(); + this.settings = Settings.builder().put(IndexMetadata.SETTING_INDEX_HIDDEN, true).build(); this.maxUpdateRunningTimes = maxUpdateRunningTimes; this.indexType = indexType; this.maxPrimaryShards = maxPrimaryShards; @@ -261,7 +261,7 @@ protected void choosePrimaryShards(CreateIndexRequest request, boolean hiddenInd .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, getNumberOfPrimaryShards()) // 1 replica for better search performance and fail-over .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) - .put("index.hidden", hiddenIndex) + .put(IndexMetadata.SETTING_INDEX_HIDDEN, hiddenIndex) ); } @@ -501,7 +501,7 @@ public void initJobIndex(ActionListener actionListener) { // accordingly. // At least 1 replica for fail-over. .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, minJobIndexReplicas + "-" + maxJobIndexReplicas) - .put("index.hidden", true) + .put(IndexMetadata.SETTING_INDEX_HIDDEN, true) ); adminClient.indices().create(request, actionListener); } catch (IOException e) { @@ -510,10 +510,42 @@ public void initJobIndex(ActionListener actionListener) { } } - public void validateCustomResultIndexAndExecute(String resultIndex, ExecutorFunction function, ActionListener listener) { + /** + * Validates the result index and executes the provided function. + * + *
+ * This method first checks if the mapping for the given result index is valid. If the mapping is not validated + * and is found to be invalid, the method logs a warning and notifies the listener of the failure. + *
+ * + *
+ * If the mapping is valid or has been previously validated, the method attempts to write and then immediately + * delete a dummy forecast result to the index. This is a workaround to verify the user's write permission on + * the custom result index, as there is currently no straightforward method to check for write permissions directly. + *
+ * + *
+ * If both write and delete operations are successful, the provided function is executed. If any step fails, + * the method logs an error and notifies the listener of the failure. + *
+     *
+     * @param <T> The type of the action listener's response.
+     * @param resultIndex The custom result index to validate.
+     * @param function The function to be executed if validation is successful.
+     * @param mappingValidated Indicates whether the mapping for the result index has been previously validated.
+     * @param listener The listener to be notified of the success or failure of the operation.
+     *
+     * @throws IllegalArgumentException If the result index mapping is found to be invalid.
+     */
+    public <T> void validateResultIndexAndExecute(
+        String resultIndex,
+        ExecutorFunction function,
+        boolean mappingValidated,
+        ActionListener<T> listener
+    ) {
         try {
-            if (!isValidResultIndexMapping(resultIndex)) {
-                logger.warn("Can't create detector with custom result index {} as its mapping is invalid", resultIndex);
+            if (!mappingValidated && !isValidResultIndexMapping(resultIndex)) {
+                logger.warn("Can't create analysis with custom result index {} as its mapping is invalid", resultIndex);
                 listener.onFailure(new IllegalArgumentException(CommonMessages.INVALID_RESULT_INDEX_MAPPING + resultIndex));
                 return;
             }
@@ -638,20 +670,20 @@ private void updateMappingIfNecessary(GroupedActionListener delegateListen
             updates.size()
         );
 
-        for (IndexType adIndex : updates) {
-            logger.info(new ParameterizedMessage("Check [{}]'s mapping", adIndex.getIndexName()));
-            shouldUpdateIndex(adIndex, ActionListener.wrap(shouldUpdate -> {
+        for (IndexType index : updates) {
+            logger.info(new ParameterizedMessage("Check [{}]'s mapping", index.getIndexName()));
+            shouldUpdateIndex(index, ActionListener.wrap(shouldUpdate -> {
                 if (shouldUpdate) {
                     adminClient
                         .indices()
                         .putMapping(
-                            new PutMappingRequest().indices(adIndex.getIndexName()).source(adIndex.getMapping(), XContentType.JSON),
+                            new PutMappingRequest().indices(index.getIndexName()).source(index.getMapping(), XContentType.JSON),
                             ActionListener.wrap(putMappingResponse -> {
                                 if (putMappingResponse.isAcknowledged()) {
-                                    logger.info(new ParameterizedMessage("Succeeded in updating [{}]'s mapping", adIndex.getIndexName()));
-                                    markMappingUpdated(adIndex);
+                                    logger.info(new ParameterizedMessage("Succeeded in updating [{}]'s mapping", index.getIndexName()));
+                                    markMappingUpdated(index);
                                 } else {
-                                    logger.error(new ParameterizedMessage("Fail to update [{}]'s mapping", adIndex.getIndexName()));
+                                    logger.error(new ParameterizedMessage("Fail to update [{}]'s mapping", index.getIndexName()));
                                 }
                                 conglomerateListeneer.onResponse(null);
                             }, exception -> {
@@ -659,7 +691,7 @@ private void updateMappingIfNecessary(GroupedActionListener delegateListen
                                     .error(
                                         new ParameterizedMessage(
                                             "Fail to update [{}]'s mapping due to [{}]",
-                                            adIndex.getIndexName(),
+                                            index.getIndexName(),
                                             exception.getMessage()
                                         )
                                     );
@@ -670,14 +702,14 @@ private void updateMappingIfNecessary(GroupedActionListener delegateListen
                     // index does not exist or the version is already up-to-date.
                    // When creating index, new mappings will be used.
                    // We don't need to update it.
- logger.info(new ParameterizedMessage("We don't need to update [{}]'s mapping", adIndex.getIndexName())); - markMappingUpdated(adIndex); + logger.info(new ParameterizedMessage("We don't need to update [{}]'s mapping", index.getIndexName())); + markMappingUpdated(index); conglomerateListeneer.onResponse(null); } }, exception -> { logger .error( - new ParameterizedMessage("Fail to check whether we should update [{}]'s mapping", adIndex.getIndexName()), + new ParameterizedMessage("Fail to check whether we should update [{}]'s mapping", index.getIndexName()), exception ); conglomerateListeneer.onFailure(exception); @@ -746,7 +778,7 @@ public void initCustomResultIndexAndExecute(String resultIndex, ExecutorFunc initCustomResultIndexDirectly(resultIndex, ActionListener.wrap(response -> { if (response.isAcknowledged()) { logger.info("Successfully created result index {}", resultIndex); - validateCustomResultIndexAndExecute(resultIndex, function, listener); + validateResultIndexAndExecute(resultIndex, function, false, listener); } else { String error = "Creating result index with mappings call not acknowledged: " + resultIndex; logger.error(error); @@ -755,14 +787,14 @@ public void initCustomResultIndexAndExecute(String resultIndex, ExecutorFunc }, exception -> { if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { // It is possible the index has been created while we sending the create request - validateCustomResultIndexAndExecute(resultIndex, function, listener); + validateResultIndexAndExecute(resultIndex, function, false, listener); } else { logger.error("Failed to create result index " + resultIndex, exception); listener.onFailure(exception); } })); } else { - validateCustomResultIndexAndExecute(resultIndex, function, listener); + validateResultIndexAndExecute(resultIndex, function, false, listener); } } @@ -788,10 +820,10 @@ public void validateCustomIndexForBackendJob( injectSecurity.close(); listener.onFailure(e); }); - validateCustomResultIndexAndExecute(resultIndex, () -> { + validateResultIndexAndExecute(resultIndex, () -> { injectSecurity.close(); function.execute(); - }, wrappedListener); + }, true, wrappedListener); } catch (Exception e) { logger.error("Failed to validate custom index for backend job " + securityLogId, e); listener.onFailure(e); @@ -872,7 +904,6 @@ public boolean isValidResultIndexMapping(String resultIndex) { return false; } LinkedHashMap mapping = (LinkedHashMap) indexMapping.get(propertyName); - boolean correctResultIndexMapping = true; for (String fieldName : RESULT_FIELD_CONFIGS.keySet()) { @@ -883,6 +914,7 @@ public boolean isValidResultIndexMapping(String resultIndex) { // feature_id={type=keyword}}}}} // if it is a map of map, Object.equals can compare them regardless of order if (!mapping.containsKey(fieldName) || !defaultSchema.equals(mapping.get(fieldName))) { + logger.warn("mapping mismatch due to {}", fieldName); correctResultIndexMapping = false; break; } diff --git a/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java b/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java new file mode 100644 index 000000000..29a2585dc --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java @@ -0,0 +1,361 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.ScrollableHitSource; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.util.ClientUtil; + +import com.google.gson.Gson; + +import io.protostuff.LinkedBuffer; + +public abstract class CheckpointDao & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + private static final Logger logger = LogManager.getLogger(CheckpointDao.class); + public static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; + public static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; + public static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; + public static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; + public static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. 
Has nothing to do:"; + + // dependencies + protected final Client client; + protected final ClientUtil clientUtil; + + // configuration + protected final String indexName; + + protected Gson gson; + + // we won't read/write a checkpoint larger than a threshold + protected final int maxCheckpointBytes; + + protected final GenericObjectPool serializeRCFBufferPool; + protected final int serializeRCFBufferSize; + + protected final IndexManagement indexUtil; + protected final Clock clock; + public static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of detector"; + + public CheckpointDao( + Client client, + ClientUtil clientUtil, + String indexName, + Gson gson, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + IndexManagementType indexUtil, + Clock clock + ) { + this.client = client; + this.clientUtil = clientUtil; + this.indexName = indexName; + this.gson = gson; + this.maxCheckpointBytes = maxCheckpointBytes; + this.serializeRCFBufferPool = serializeRCFBufferPool; + this.serializeRCFBufferSize = serializeRCFBufferSize; + this.indexUtil = indexUtil; + this.clock = clock; + } + + protected void putModelCheckpoint(String modelId, Map source, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + onCheckpointNotExist(source, modelId, listener); + } + } + + /** + * Update the model doc using fields in source. This ensures we won't touch + * the old checkpoint and nodes with old/new logic can coexist in a cluster. + * This is useful for introducing compact rcf new model format. + * + * @param source fields to update + * @param modelId model Id, used as doc id in the checkpoint index + * @param listener Listener to return response + */ + protected void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { + + UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); + updateRequest.doc(source); + // If the document does not already exist, the contents of the upsert element are inserted as a new document. 
+        // If the document exists, update fields in the map
+        updateRequest.docAsUpsert(true);
+        clientUtil
+            .asyncRequest(
+                updateRequest,
+                client::update,
+                ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)
+            );
+    }
+
+    protected void onCheckpointNotExist(Map<String, Object> source, String modelId, ActionListener<Void> listener) {
+        indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> {
+            if (initResponse.isAcknowledged()) {
+                saveModelCheckpointAsync(source, modelId, listener);
+
+            } else {
+                throw new RuntimeException("Creating checkpoint with mappings call not acknowledged.");
+            }
+        }, exception -> {
+            if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
+                // It is possible the index has been created while we were sending the create request
+                saveModelCheckpointAsync(source, modelId, listener);
+            } else {
+                logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception);
+            }
+        }));
+    }
+
+    protected Map.Entry<LinkedBuffer, Boolean> checkoutOrNewBuffer() {
+        LinkedBuffer buffer = null;
+        boolean isCheckout = true;
+        try {
+            buffer = serializeRCFBufferPool.borrowObject();
+        } catch (Exception e) {
+            logger.warn("Failed to borrow a buffer from pool", e);
+        }
+        if (buffer == null) {
+            buffer = LinkedBuffer.allocate(serializeRCFBufferSize);
+            isCheckout = false;
+        }
+        return new SimpleImmutableEntry<>(buffer, isCheckout);
+    }
+
+    /**
+     * Deletes the model checkpoint for the model.
+     *
+     * @param modelId id of the model
+     * @param listener onResponse is called with null when the operation is completed
+     */
+    public void deleteModelCheckpoint(String modelId, ActionListener<Void> listener) {
+        clientUtil
+            .asyncRequest(
+                new DeleteRequest(indexName, modelId),
+                client::delete,
+                ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)
+            );
+    }
+
+    protected void logFailure(BulkByScrollResponse response, String id) {
+        if (response.isTimedOut()) {
+            logger.warn(CheckpointDao.TIMEOUT_LOG_MSG + " {}", id);
+        } else if (!response.getBulkFailures().isEmpty()) {
+            logger.warn(CheckpointDao.BULK_FAILURE_LOG_MSG + " {}", id);
+            for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) {
+                logger.warn(bulkFailure);
+            }
+        } else {
+            logger.warn(CheckpointDao.SEARCH_FAILURE_LOG_MSG + " {}", id);
+            for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) {
+                logger.warn(searchFailure);
+            }
+        }
+    }
+
+    /**
+     * Should we save the checkpoint or not?
+     * @param lastCheckpointTime Last checkpoint time
+     * @param forceWrite Save no matter what
+     * @param checkpointInterval Checkpoint interval
+     * @param clock UTC clock
+     *
+     * @return true when forceWrite is true or we haven't saved a checkpoint in the
+     * last checkpoint interval; false otherwise
+     */
+    public boolean shouldSave(Instant lastCheckpointTime, boolean forceWrite, Duration checkpointInterval, Clock clock) {
+        return (lastCheckpointTime != Instant.MIN && lastCheckpointTime.plus(checkpointInterval).isBefore(clock.instant())) || forceWrite;
+    }
+
+    public void batchWrite(BulkRequest request, ActionListener<BulkResponse> listener) {
+        if (indexUtil.doesCheckpointIndexExist()) {
+            clientUtil.execute(BulkAction.INSTANCE, request, listener);
+        } else {
+            indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> {
+                if (initResponse.isAcknowledged()) {
+                    clientUtil.execute(BulkAction.INSTANCE, request, listener);
+                } else {
+                    // create index failure. Notify callers using listener.
+                    listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged."));
+                }
+            }, exception -> {
+                if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
+                    // It is possible the index has been created while we were sending the create request
+                    clientUtil.execute(BulkAction.INSTANCE, request, listener);
+                } else {
+                    logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception);
+                    listener.onFailure(exception);
+                }
+            }));
+        }
+    }
+
+    /**
+     * Serialize samples.
+     * @param samples input samples
+     * @return serialized object
+     */
+    protected Optional<Object[]> toCheckpoint(Queue<Sample> samples) {
+        if (samples == null) {
+            return Optional.empty();
+        }
+        return Optional.of(samples.toArray());
+    }
+
+    public void batchRead(MultiGetRequest request, ActionListener<MultiGetResponse> listener) {
+        clientUtil.execute(MultiGetAction.INSTANCE, request, listener);
+    }
+
+    public void read(GetRequest request, ActionListener<GetResponse> listener) {
+        clientUtil.execute(GetAction.INSTANCE, request, listener);
+    }
+
+    /**
+     * Delete checkpoints associated with a config. Used in multi-entity detectors.
+     * @param configId Config Id
+     */
+    public void deleteModelCheckpointByConfigId(String configId) {
+        // A bulk delete request is performed for each batch of matching documents. If a
+        // search or bulk request is rejected, the requests are retried up to 10 times,
+        // with exponential back off. If the maximum retry limit is reached, processing
+        // halts and all failed requests are returned in the response. Any delete
+        // requests that completed successfully still stick; they are not rolled back.
+        DeleteByQueryRequest deleteRequest = createDeleteCheckpointRequest(configId);
+        logger.info("Delete checkpoints of config {}", configId);
+        client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> {
+            if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) {
+                logFailure(response, configId);
+            }
+            // can report 0 docs deleted because:
+            // 1) we cannot find matching docs
+            // 2) bad stats from OpenSearch. In this case, docs are deleted, but
+            // OpenSearch says deleted is 0.
+            logger.info("{} " + CheckpointDao.DOC_GOT_DELETED_LOG_MSG, response.getDeleted());
+        }, exception -> {
+            if (exception instanceof IndexNotFoundException) {
+                logger.info(CheckpointDao.INDEX_DELETED_LOG_MSG + " {}", configId);
+            } else {
+                // The daily cron will eventually delete these checkpoints; just log the failure.
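Stepping back to the shouldSave predicate defined above: it is the write-side rate limiter for this class, letting a checkpoint go out at most once per checkpointInterval unless the caller forces the write, and treating Instant.MIN as "never saved". A small self-contained illustration (hypothetical values; the local method mirrors the predicate, not the full class):

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneOffset;

public class ShouldSaveDemo {
    // Same predicate as CheckpointDao.shouldSave: write when forced, or when
    // the last checkpoint is older than one checkpoint interval.
    static boolean shouldSave(Instant lastCheckpointTime, boolean forceWrite, Duration checkpointInterval, Clock clock) {
        return (lastCheckpointTime != Instant.MIN && lastCheckpointTime.plus(checkpointInterval).isBefore(clock.instant()))
            || forceWrite;
    }

    public static void main(String[] args) {
        Clock clock = Clock.fixed(Instant.parse("2024-01-01T12:00:00Z"), ZoneOffset.UTC);
        Duration interval = Duration.ofHours(12);
        // Saved 13 hours ago: the interval has elapsed, so save again.
        System.out.println(shouldSave(Instant.parse("2023-12-31T23:00:00Z"), false, interval, clock)); // true
        // Saved 1 hour ago: skip, unless the write is forced.
        System.out.println(shouldSave(Instant.parse("2024-01-01T11:00:00Z"), false, interval, clock)); // false
        System.out.println(shouldSave(Instant.parse("2024-01-01T11:00:00Z"), true, interval, clock)); // true
        // Instant.MIN marks "never saved": only a forced write goes through.
        System.out.println(shouldSave(Instant.MIN, false, interval, clock)); // false
    }
}
```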
+ logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, exception); + } + })); + } + + protected Optional> processRawCheckpoint(GetResponse response) { + return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); + } + + /** + * Process a checkpoint GetResponse and return the EntityModel object + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time + */ + public ModelState processHCGetResponse(GetResponse response, String modelId, String configId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromEntityModelCheckpoint(checkpointString.get(), modelId, configId); + } else { + return null; + } + } + + /** + * Process a checkpoint GetResponse and return the EntityModel object + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time + */ + public ModelState processSingleStreamGetResponse(GetResponse response, String modelId, String configId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromSingleStreamModelCheckpoint(checkpointString.get(), modelId, configId); + } else { + return null; + } + } + + protected abstract ModelState fromEntityModelCheckpoint(Map checkpoint, String modelId, String configId); + + protected abstract ModelState fromSingleStreamModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ); + + public abstract Map toIndexSource(ModelState modelState) throws IOException; + + protected abstract DeleteByQueryRequest createDeleteCheckpointRequest(String configId); + + protected Deque loadSampleQueue(Map checkpoint, String modelId) { + Deque sampleQueue = new ArrayDeque<>(); + List> samples = (List>) checkpoint.get(CommonName.ENTITY_SAMPLE_QUEUE); + if (samples != null) { + for (int i = 0; i < samples.size(); i++) { + try { + Sample sample = Sample.extractSample(samples.get(i)); + if (sample != null) { + sampleQueue.add(sample); + } + } catch (Exception e) { + logger.warn("Exception while deserializing samples for " + modelId, e); + } + } + } + // can be null when checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training. + return sampleQueue; + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java new file mode 100644 index 000000000..b477f454a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A customized ConcurrentHashMap that can automatically consume and release memory. 
+ * This enables minimum change to our single-stream code as we just have to replace + * the map implementation. + * + * Note: this is mainly used for single-stream configs. The key is model id. + */ +public class MemoryAwareConcurrentHashmap extends + ConcurrentHashMap> { + protected final MemoryTracker memoryTracker; + + public MemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) { + this.memoryTracker = memoryTracker; + } + + @Override + public ModelState remove(Object key) { + ModelState deletedModelState = super.remove(key); + if (deletedModelState != null && deletedModelState.getModel().isPresent()) { + long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel().get()); + memoryTracker.releaseMemory(memoryToRelease, true, Origin.REAL_TIME_DETECTOR); + } + return deletedModelState; + } + + @Override + public ModelState put(String key, ModelState value) { + ModelState previousAssociatedState = super.put(key, value); + if (value != null && value.getModel().isPresent()) { + long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel().get()); + memoryTracker.consumeMemory(memoryToConsume, true, Origin.REAL_TIME_DETECTOR); + } + return previousAssociatedState; + } + + /** + * Gets all of a config's model sizes hosted on a node + * + * @param configId config Id + * @return a map of model id to its memory size + */ + public Map getModelSize(String configId) { + Map res = new HashMap<>(); + super.entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) + .forEach(entry -> { + Optional modelOptional = entry.getValue().getModel(); + if (modelOptional.isPresent()) { + res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get())); + } + }); + return res; + } + + /** + * Checks if a model exists for the given config. + * @param configId Config Id + * @return `true` if the model exists, `false` otherwise. + */ + public boolean doesModelExist(String configId) { + return super.entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) + .anyMatch(n -> true); + } + + public boolean hostIfPossible(String modelId, ModelState toUpdate) { + return Optional + .ofNullable(toUpdate) + .filter(state -> state.getModel().isPresent()) + .filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get())) + .map(state -> { + super.put(modelId, toUpdate); + return true; + }) + .orElse(false); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java new file mode 100644 index 000000000..8b690c910 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java @@ -0,0 +1,632 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ml; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.util.Throwables; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.CleanState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.caching.DoorKeeper; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * The class bootstraps a model by performing a cold start + */ +public abstract class ModelColdStart & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker> + implements + MaintenanceState, + CleanState { + private static final Logger logger = LogManager.getLogger(ModelColdStart.class); + + private final Duration modelTtl; + + // A bloom filter checked before cold start to ensure we don't repeatedly + // retry cold start of the same model. + // keys are detector ids. + protected Map doorKeepers; + protected Instant lastThrottledColdStartTime; + protected int coolDownMinutes; + protected final Clock clock; + protected final ThreadPool threadPool; + protected final int numMinSamples; + protected CheckpointWriteWorkerType checkpointWriteWorker; + // make sure rcf use a specific random seed. Otherwise, we will use a random random (not a typo) seed. 
+ // this is mainly used for testing to make sure the model we trained and the reference rcf produce + // the same results + protected final long rcfSeed; + protected final int numberOfTrees; + protected final int rcfSampleSize; + protected final double thresholdMinPvalue; + protected final double rcfTimeDecay; + protected final double initialAcceptFraction; + protected final NodeStateManager nodeStateManager; + protected final int defaulStrideLength; + protected final int defaultNumberOfSamples; + protected final SearchFeatureDao searchFeatureDao; + protected final FeatureManager featureManager; + protected final int maxRoundofColdStart; + protected final String threadPoolName; + protected final AnalysisType context; + + public ModelColdStart( + Duration modelTtl, + int coolDownMinutes, + Clock clock, + ThreadPool threadPool, + int numMinSamples, + CheckpointWriteWorkerType checkpointWriteWorker, + long rcfSeed, + int numberOfTrees, + int rcfSampleSize, + double thresholdMinPvalue, + double rcfTimeDecay, + NodeStateManager nodeStateManager, + int defaultSampleStride, + int defaultTrainSamples, + SearchFeatureDao searchFeatureDao, + FeatureManager featureManager, + int maxRoundofColdStart, + String threadPoolName, + AnalysisType context + ) { + this.modelTtl = modelTtl; + this.coolDownMinutes = coolDownMinutes; + this.clock = clock; + this.threadPool = threadPool; + this.numMinSamples = numMinSamples; + this.checkpointWriteWorker = checkpointWriteWorker; + this.rcfSeed = rcfSeed; + this.numberOfTrees = numberOfTrees; + this.rcfSampleSize = rcfSampleSize; + this.thresholdMinPvalue = thresholdMinPvalue; + this.rcfTimeDecay = rcfTimeDecay; + + this.doorKeepers = new ConcurrentHashMap<>(); + this.lastThrottledColdStartTime = Instant.MIN; + this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; + + this.nodeStateManager = nodeStateManager; + this.defaulStrideLength = defaultSampleStride; + this.defaultNumberOfSamples = defaultTrainSamples; + this.searchFeatureDao = searchFeatureDao; + this.featureManager = featureManager; + this.maxRoundofColdStart = maxRoundofColdStart; + this.threadPoolName = threadPoolName; + this.context = context; + } + + @Override + public void maintenance() { + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { + String id = doorKeeperEntry.getKey(); + DoorKeeper doorKeeper = doorKeeperEntry.getValue(); + if (doorKeeper.expired(modelTtl)) { + doorKeepers.remove(id); + } else { + doorKeeper.maintenance(); + } + }); + } + + @Override + public void clear(String id) { + doorKeepers.remove(id); + } + + /** + * Train models + * @param coldStartRequest cold start request + * @param configId Config Id + * @param modelState Model state + * @param listener callback before the method returns whenever ColdStarter + * finishes training or encounters exceptions. The listener helps notify the + * cold start queue to pull another request (if any) to execute. 
+ */ + public void trainModel( + FeatureRequest coldStartRequest, + String configId, + ModelState modelState, + ActionListener>> listener + ) { + nodeStateManager.getConfig(configId, context, ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + logger.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); + listener.onFailure(new TimeSeriesException(configId, "fail to find config")); + return; + } + + Config config = configOptional.get(); + + String modelId = modelState.getModelId(); + + if (modelState.getSamples().size() < this.numMinSamples) { + // we cannot get last RCF score since cold start happens asynchronously + coldStart(modelId, coldStartRequest, modelState, config, listener); + } else { + try { + trainModelFromExistingSamples(modelState, coldStartRequest.getEntity(), config, coldStartRequest.getTaskId()); + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + } + }, listener::onFailure)); + } + + public void trainModelFromExistingSamples(ModelState modelState, Optional entity, Config config, String taskId) { + Pair>, Sample> continuousSamples = featureManager.getContinuousSamples(config, modelState.getSamples()); + trainModelFromDataSegments(continuousSamples, entity, modelState, config, taskId); + } + + /** + * Training model + * @param modelId model Id corresponding to the entity + * @param coldStartRequest cold start request + * @param modelState model state + * @param config config accessor + * @param listener call back to call after cold start + */ + private void coldStart( + String modelId, + FeatureRequest coldStartRequest, + ModelState modelState, + Config config, + ActionListener>> listener + ) { + logger.debug("Trigger cold start for {}", modelId); + + if (modelState == null) { + listener.onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "Cannot have empty model state"))); + return; + } + + if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + listener.onResponse(null); + return; + } + + String configId = config.getId(); + boolean earlyExit = true; + try { + // Won't retry real-time cold start within 60 intervals for an entity + // coldStartRequest.getTaskId() == null in real-time cold start + + DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(configId, id -> { + // reset every 60 intervals + return new DoorKeeper( + TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock, + TimeSeriesSettings.DOOR_KEEPER_COUNT_THRESHOLD + ); + }); + + if (doorKeeper.appearsMoreThanThreshold(modelId)) { + return; + } + + doorKeeper.put(modelId); + + ActionListener>, Sample>> coldStartCallBack = ActionListener.wrap(trainingData -> { + try { + List> endTimeToDataList = null; + if (trainingData != null && trainingData.getKey() != null) { + List> dataPoints = trainingData.getKey(); + // only train models if we have enough samples + if (dataPoints.size() >= numMinSamples) { + // The function trainModelFromDataSegments will save a trained a model. trainModelFromDataSegments is called by + // multiple places, so I want to make the saving model implicit just in case I forgot. 
+ endTimeToDataList = trainModelFromDataSegments( + trainingData, + coldStartRequest.getEntity(), + modelState, + config, + coldStartRequest.getTaskId() + ); + logger.info("Succeeded in training entity: {}", modelId); + } else { + // save to checkpoint + checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM); + logger.info("Not enough data to train model: {}, currently we have {}", modelId, dataPoints.size()); + } + } else { + logger.info("Cannot get training data for {}", modelId); + } + listener.onResponse(endTimeToDataList); + } catch (Exception e) { + listener.onFailure(e); + } + }, exception -> { + try { + logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); + Throwable cause = Throwables.getRootCause(exception); + if (ExceptionUtil.isOverloaded(cause)) { + logger.error("too many requests"); + lastThrottledColdStartTime = Instant.now(); + } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { + // e.g., cannot find anomaly detector + nodeStateManager.setException(configId, exception); + } else { + nodeStateManager.setException(configId, new TimeSeriesException(configId, cause)); + } + listener.onFailure(exception); + } catch (Exception e) { + listener.onFailure(e); + } + }); + + threadPool + .executor(threadPoolName) + .execute( + () -> getColdStartData( + configId, + coldStartRequest, + config.getImputer(), + new ThreadedActionListener<>(logger, threadPool, threadPoolName, coldStartCallBack, false) + ) + ); + earlyExit = false; + } finally { + if (earlyExit) { + listener.onResponse(null); + } + } + } + + /** + * Get training data for an entity. + * + * We first note the maximum and minimum timestamp, and sample at most 24 points + * (with 60 points apart between two neighboring samples) between those minimum + * and maximum timestamps. Samples can be missing. We only interpolate points + * between present neighboring samples. We then transform samples and interpolate + * points to shingles. Finally, full shingles will be used for cold start. + * + * @param configId config Id + * @param coldStartRequest cold start request + * @param imputer imputation utility + * @param listener listener to return training data + */ + private void getColdStartData( + String configId, + FeatureRequest coldStartRequest, + Imputer imputer, + ActionListener>, Sample>> listener + ) { + ActionListener> getDetectorListener = ActionListener.wrap(configOp -> { + if (!configOp.isPresent()) { + listener.onFailure(new EndRunException(configId, "Config is not available.", false)); + return; + } + Config config = configOp.get(); + + ActionListener> minTimeListener = ActionListener.wrap(earliest -> { + if (earliest.isPresent()) { + long startTimeMs = earliest.get().longValue(); + + // End time uses milliseconds as start time is assumed to be in milliseconds. + // Opensearch uses a set of preconfigured formats to recognize and parse these + // strings into a long value + // representing milliseconds-since-the-epoch in UTC. + // More on https://tinyurl.com/wub4fk92 + // also, since we want to use current feature to score, we don't use current interval + // [current start, current end] for training. 
So we fetch training data ending at current start + long endTimeMs = coldStartRequest.getDataStartTimeMillis(); + Pair params = selectRangeParam(config); + int stride = params.getLeft(); + int numberOfSamples = params.getRight(); + + // we start with round 0 + getFeatures( + listener, + 0, + Pair.of(new ArrayList<>(), new Sample()), + config, + coldStartRequest.getEntity(), + stride, + numberOfSamples, + startTimeMs, + endTimeMs, + imputer + ); + } else { + listener.onResponse(Pair.of(new ArrayList<>(), new Sample())); + } + }, listener::onFailure); + + searchFeatureDao + .getMinDataTime( + config, + coldStartRequest.getEntity(), + context, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, minTimeListener, false) + ); + + }, listener::onFailure); + + nodeStateManager + .getConfig(configId, context, new ThreadedActionListener<>(logger, threadPool, threadPoolName, getDetectorListener, false)); + } + + /** + * Select strideLength and numberOfSamples, where strideLength is the number of intervals + * between two samples and numberOfSamples is the number of training samples to fetch. If we disable + * interpolation, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples. + * + * Algorithm: + * + * delta is the length of the detector interval in minutes. + * + * 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil ( (shingleSize + 32)/ 24 )*24 + * and strideLength = 60/delta. Note that if there is enough data, we may have a lot more than shingleSize+32 + * points, which is only good. This step tries to match the data with an hourly pattern. + * 2. Otherwise, set numberOfSamples = (shingleSize + 32) and strideLength = 1. + * This should be an uncommon case as we are assuming most users think in terms of multiples of 5 minutes + * (say 10 or 30 minutes). But if someone wants a 23-minute interval, and the system permits it, + * we give it to them. In this case, we disable interpolation because interpolation is meant to follow the + * hourly pattern; that is why we use 60 as the dividend in case 1, and the 23-minute case does not fit that pattern. + * Note that the smallest delta that does not divide 60 is 7, which is already quite a long wait for one data point. + * @return the chosen strideLength and numberOfSamples + */ + private Pair selectRangeParam(Config config) { + int shingleSize = config.getShingleSize(); + if (isInterpolationInColdStartEnabled()) { + long delta = config.getIntervalInMinutes(); + + int strideLength = defaulStrideLength; + int numberOfSamples = defaultNumberOfSamples; + if (delta <= 30 && 60 % delta == 0) { + strideLength = (int) (60 / delta); + numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24; + } else { + strideLength = 1; + numberOfSamples = shingleSize + numMinSamples; + } + return Pair.of(strideLength, numberOfSamples); + } else { + return Pair.of(1, shingleSize + numMinSamples); + } + + } + + private void getFeatures( + ActionListener>, Sample>> listener, + int round, + Pair>, Sample> lastRoundDataSample, + Config config, + Optional entity, + int stride, + int numberOfSamples, + long startTimeMs, + long endTimeMs, + Imputer imputer + ) { + if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < config.getIntervalInMilliseconds()) { + listener.onResponse(lastRoundDataSample); + return; + } + + // Create ranges in ascending order where the last sample's end time is the given endTimeMs. + // Sample ranges are also in ascending order in OpenSearch's response.
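+ // Editorial example derived from the selectRangeParam javadoc above (values are hypothetical): + // with a 10-minute interval and shingle size 8, stride = 60 / 10 = 6 and + // numberOfSamples = ceil((8 + 32) / 24.0) * 24 = 48, so the call below yields 48 one-interval + // buckets spaced 6 intervals apart, the last one ending at endTimeMs.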
+ List> sampleRanges = getTrainSampleRanges(config, startTimeMs, endTimeMs, stride, numberOfSamples); + + if (sampleRanges.isEmpty()) { + listener.onResponse(lastRoundDataSample); + return; + } + + ActionListener>> getFeatureListener = ActionListener.wrap(featureSamples -> { + + if (featureSamples.size() != sampleRanges.size()) { + logger + .error( + "We don't expect different featureSample size {} and sample range size {}.", + featureSamples.size(), + sampleRanges.size() + ); + listener.onResponse(lastRoundDataSample); + return; + } + + int totalNumSamples = featureSamples.size(); + int numEnabledFeatures = config.getEnabledFeatureIds().size(); + + double[][] trainingData = new double[totalNumSamples][numEnabledFeatures]; + + // featureSamples are in ascending order of time. + for (int index = 0; index < featureSamples.size(); index++) { + Optional featuresOptional = featureSamples.get(index); + if (featuresOptional.isPresent()) { + // a present sample keeps its position; its index is aligned with sampleRanges + trainingData[index] = featuresOptional.get(); + } else { + // create an array of Double.NaN + trainingData[index] = DoubleStream.generate(() -> Double.NaN).limit(numEnabledFeatures).toArray(); + } + } + + double[][] currentRoundColdStartData = imputer.impute(trainingData, totalNumSamples); + + Pair>, Sample> concatenatedDataSample = null; + List> lastRoundColdStartData = lastRoundDataSample.getKey(); + // only concatenate when the previous round actually produced samples + if (lastRoundColdStartData != null && lastRoundColdStartData.size() > 0) { + int trainingDataLength = currentRoundColdStartData.length + lastRoundColdStartData.size(); + double[][] concatenated = new double[trainingDataLength][numEnabledFeatures]; + long[] dataEndTime = new long[trainingDataLength]; + int startIndex = 0; + for (int i = 0; i < currentRoundColdStartData.length; i++) { + dataEndTime[startIndex] = sampleRanges.get(i).getValue(); + concatenated[startIndex++] = currentRoundColdStartData[i]; + } + for (int i = 0; i < lastRoundColdStartData.size(); i++) { + Entry lastRoundEntry = lastRoundColdStartData.get(i); + dataEndTime[startIndex] = lastRoundEntry.getKey(); + concatenated[startIndex++] = lastRoundEntry.getValue(); + } + + // re-impute across the round boundary and return the imputed rows keyed by their end times + double[][] imputedConcatenated = imputer.impute(concatenated, concatenated.length); + List> trainingDataToReturn = new ArrayList<>(); + for (int i = 0; i < imputedConcatenated.length; i++) { + trainingDataToReturn.add(new SimpleImmutableEntry<>(dataEndTime[i], imputedConcatenated[i])); + } + concatenatedDataSample = Pair.of(trainingDataToReturn, lastRoundDataSample.getValue()); + } else { + List> currentRoundReturn = new ArrayList<>(); + for (int i = 0; i < currentRoundColdStartData.length; i++) { + // map of data end time to its value + currentRoundReturn.add(new SimpleImmutableEntry<>(sampleRanges.get(i).getValue(), currentRoundColdStartData[i])); + } + concatenatedDataSample = Pair + .of( + currentRoundReturn, + new Sample( + currentRoundColdStartData[currentRoundColdStartData.length - 1], + Instant.ofEpochMilli(endTimeMs - config.getIntervalInMilliseconds()), + Instant.ofEpochMilli(endTimeMs) + ) + ); + } + + // If the first round of probe provides (32+shingleSize) points, we are done (note that if S0 is + // missing or all Si for
some i > N is missing, then we would miss a lot of points). + // Otherwise we can issue another round of query: if there is any sample in the + // second round, then we would have 32 + shingleSize points. If there is no sample + // in the second round, then we should wait for real data. + if (currentRoundColdStartData.length >= config.getShingleSize() + numMinSamples || round + 1 >= maxRoundofColdStart) { + listener.onResponse(concatenatedDataSample); + } else { + // sampleRanges are in ascending order, so the first range belongs to the earliest sample; + // its start time is the endTimeMs of the next round of probe. + long earliestSampleStartTime = sampleRanges.get(0).getKey(); + getFeatures( + listener, + round + 1, + concatenatedDataSample, + config, + entity, + stride, + numberOfSamples, + startTimeMs, + earliestSampleStartTime, + imputer + ); + } + }, listener::onFailure); + + try { + searchFeatureDao + .getColdStartSamplesForPeriods( + config, + sampleRanges, + entity, + // Accept empty bucket. + // 0, as returned by the engine, should constitute a valid answer; “null” is a missing answer. It may be that 0 + // is meaningless in some cases, but 0 is also meaningful in other cases. It may be that the query defining the + // metric is ill-formed, but that cannot be solved by the cold-start strategy of the AD plugin — if we attempt to do + // that, we will have issues with legitimate interpretations of 0. + true, + context, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, getFeatureListener, false) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Get training sample ranges within a time range. + * + * @param config accessor to config + * @param startMilli range start + * @param endMilli range end + * @param stride the number of intervals between two samples + * @param numberOfSamples maximum training samples to fetch + * @return list of sample time ranges in ascending order + */ + private List> getTrainSampleRanges(Config config, long startMilli, long endMilli, int stride, int numberOfSamples) { + long bucketSize = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMillis(); + int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize); + // adjust if numStrides is more than the max samples + int numStrides = Math.min((int) Math.floor(numBuckets / (double) stride), numberOfSamples); + List> sampleRanges = Stream + .iterate(endMilli, i -> i - stride * bucketSize) + .limit(numStrides) + .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time)) + .collect(Collectors.toList()); + + // Reverse the list to get time ranges in ascending order + Collections.reverse(sampleRanges); + + return sampleRanges; + } + + protected abstract List> trainModelFromDataSegments( + Pair>, Sample> dataPoints, + Optional entity, + ModelState state, + Config config, + String taskId + ); + + protected abstract boolean isInterpolationInColdStartEnabled(); +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java new file mode 100644 index 000000000..56cedfca3 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java @@ -0,0 +1,196 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.ml; + +import java.time.Clock; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class ModelManager, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart> { + + private static final Logger LOG = LogManager.getLogger(ModelManager.class); + + public enum ModelType { + RCF("rcf"), + THRESHOLD("threshold"), + TRCF("trcf"), + RCFCASTER("rcf_caster"); + + private String name; + + ModelType(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + protected final int rcfNumTrees; + protected final int rcfNumSamplesInTree; + protected final double rcfTimeDecay; + protected final int rcfNumMinSamples; + protected ColdStarterType coldStarter; + protected MemoryTracker memoryTracker; + protected final Clock clock; + protected FeatureManager featureManager; + protected final CheckpointDaoType checkpointDao; + + public ModelManager( + int rcfNumTrees, + int rcfNumSamplesInTree, + double rcfTimeDecay, + int rcfNumMinSamples, + ColdStarterType coldStarter, + MemoryTracker memoryTracker, + Clock clock, + FeatureManager featureManager, + CheckpointDaoType checkpointDao + ) { + this.rcfNumTrees = rcfNumTrees; + this.rcfNumSamplesInTree = rcfNumSamplesInTree; + this.rcfTimeDecay = rcfTimeDecay; + this.rcfNumMinSamples = rcfNumMinSamples; + this.coldStarter = coldStarter; + this.memoryTracker = memoryTracker; + this.clock = clock; + this.featureManager = featureManager; + this.checkpointDao = checkpointDao; + } + + public IntermediateResultType getResult( + Sample sample, + ModelState modelState, + String modelId, + Optional entity, + Config config, + String taskId + ) { + IntermediateResultType result = createEmptyResult(); + if (modelState != null) { + Optional entityModel = modelState.getModel(); + + if (entityModel.isEmpty()) { + coldStarter.trainModelFromExistingSamples(modelState, entity, config, taskId); + } + + if (modelState.getModel().isPresent()) { + result = score(sample, modelId, modelState, config); + } else { + modelState.addSample(sample); + } + } + return result; + } + + public void clearModels(String detectorId, Map models, ActionListener listener) { + Iterator id = models.keySet().iterator(); + clearModelForIterator(detectorId, models, id, listener); + } + + protected void clearModelForIterator(String detectorId, Map models, Iterator idIter, 
ActionListener listener) { + if (idIter.hasNext()) { + String modelId = idIter.next(); + if (SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(detectorId)) { + models.remove(modelId); + checkpointDao + .deleteModelCheckpoint( + modelId, + ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) + ); + } else { + clearModelForIterator(detectorId, models, idIter, listener); + } + } else { + listener.onResponse(null); + } + } + + @SuppressWarnings("unchecked") + public IntermediateResultType score( + Sample sample, + String modelId, + ModelState modelState, + Config config + ) { + + IntermediateResultType result = createEmptyResult(); + Optional model = modelState.getModel(); + try { + if (model != null && model.isPresent()) { + RCFModelType rcfModel = model.get(); + + Pair>, Sample> dataSamplePair = featureManager + .getContinuousSamples(config, modelState.getSamples(), modelState.getLastProcessedSample(), sample); + + List> data = dataSamplePair.getKey(); + RCFDescriptor lastResult = null; + for (int i = 0; i < data.size(); i++) { + // we are sure that the process method will indeed return an instance of RCFDescriptor. + lastResult = (RCFDescriptor) rcfModel.process(data.get(i).getValue(), 0); + } + modelState.clearSamples(); + + if (lastResult != null) { + result = toResult(rcfModel.getForest(), lastResult); + } + + modelState.setLastProcessedSample(dataSamplePair.getValue()); + } + } catch (Exception e) { + LOG + .error( + new ParameterizedMessage( + "Fail to score for [{}]: model Id [{}], feature [{}]", + modelState.getEntity().isEmpty() ? modelState.getConfigId() : modelState.getEntity().get(), + modelId, + Arrays.toString(sample.getValueList()) + ), + e + ); + throw e; + } finally { + modelState.setLastUsedTime(clock.instant()); + } + return result; + } + + protected abstract IntermediateResultType createEmptyResult(); + + protected abstract IntermediateResultType toResult( + RandomCutForest forecast, + RCFDescriptor castDescriptor + ); +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelState.java b/src/main/java/org/opensearch/timeseries/ml/ModelState.java similarity index 58% rename from src/main/java/org/opensearch/ad/ml/ModelState.java rename to src/main/java/org/opensearch/timeseries/ml/ModelState.java index bb9050ecb..71ae73eb4 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelState.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelState.java @@ -9,92 +9,89 @@ * GitHub history for details. */ -package org.opensearch.ad.ml; +package org.opensearch.timeseries.ml; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Deque; import java.util.HashMap; import java.util.Map; +import java.util.Optional; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; -/** - * A ML model and states such as usage. 
- */ -public class ModelState implements ExpiringState { - +public class ModelState implements org.opensearch.timeseries.ExpiringState { public static String MODEL_TYPE_KEY = "model_type"; public static String LAST_USED_TIME_KEY = "last_used_time"; public static String LAST_CHECKPOINT_TIME_KEY = "last_checkpoint_time"; public static String PRIORITY_KEY = "priority"; - private T model; - private String modelId; - private String detectorId; - private String modelType; + + protected T model; + protected String modelId; + protected String configId; + protected String modelType; // time when the ML model was used last time - private Instant lastUsedTime; - private Instant lastCheckpointTime; - private Clock clock; - private float priority; + protected Instant lastUsedTime; + protected Instant lastCheckpointTime; + protected Clock clock; + protected float priority; + protected Sample lastProcessedSample; + protected Deque samples; + protected Optional entity; /** * Constructor. * * @param model ML model * @param modelId Id of model partition - * @param detectorId Id of detector this model partition is used for + * @param configId Id of analysis this model partition is used for * @param modelType type of model * @param clock UTC clock * @param priority Priority of the model state. Used in multi-entity detectors' cache. + * @param lastProcessedSample last processed sample. Used in interpolation. + * @param entity Entity info if this is a HC entity state + * @param samples existing samples that haven't been processed */ - public ModelState(T model, String modelId, String detectorId, String modelType, Clock clock, float priority) { + public ModelState( + T model, + String modelId, + String configId, + String modelType, + Clock clock, + float priority, + Sample lastProcessedSample, + Optional entity, + Deque samples + ) { this.model = model; this.modelId = modelId; - this.detectorId = detectorId; + this.configId = configId; this.modelType = modelType; this.lastUsedTime = clock.instant(); // this is inaccurate until we find the last checkpoint time from disk this.lastCheckpointTime = Instant.MIN; this.clock = clock; this.priority = priority; + this.lastProcessedSample = lastProcessedSample; + this.entity = entity; + this.samples = samples; } /** - * Create state with zero priority. Used in single-entity detector. + * Constructor. Used in single-stream analysis. * - * @param Model object's type - * @param model The actual model object - * @param modelId Model Id - * @param detectorId Detector Id - * @param modelType Model type like RCF model + * @param model ML model + * @param modelId Id of model partition + * @param configId Id of analysis this model partition is used for + * @param modelType type of model * @param clock UTC clock - * - * @return the created model state - */ - public static ModelState createSingleEntityModelState( - T model, - String modelId, - String detectorId, - String modelType, - Clock clock - ) { - return new ModelState<>(model, modelId, detectorId, modelType, clock, 0f); - } - - /** - * Returns the ML model. - * - * @return the ML model. + * @param lastProcessedSample last processed sample. Used in interpolation. 
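+ * + * Usage sketch (editorial; arguments are hypothetical): + * new ModelState<>(trcf, modelId, configId, ModelType.TRCF.getName(), clock, new Sample()) + * creates a single-stream state with zero priority, no entity, and an empty sample queue.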
*/ - public T getModel() { - return this.model; - } - - public void setModel(T model) { - this.model = model; + public ModelState(T model, String modelId, String configId, String modelType, Clock clock, Sample lastProcessedSample) { + this(model, modelId, configId, modelType, clock, 0, lastProcessedSample, Optional.empty(), new ArrayDeque<>()); } /** @@ -106,15 +103,6 @@ public String getModelId() { return modelId; } - /** - * Gets the detectorID of the model - * - * @return detectorId associated with the model - */ - public String getId() { - return detectorId; - } - /** * Gets the type of the model * @@ -172,16 +160,90 @@ public void setPriority(float priority) { this.priority = priority; } + public Sample getLastProcessedSample() { + return lastProcessedSample; + } + + public void setLastProcessedSample(Sample lastProcessedSample) { + this.lastProcessedSample = lastProcessedSample; + } + + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); + } + + /** + * Gets the Config ID of the model + * + * @return the config id associated with the model + */ + public String getConfigId() { + return configId; + } + + /** + * In old checkpoint mapping, we don't have entity. It's fine we are missing + * entity as it is mostly used for debugging. + * @return entity + */ + public Optional getEntity() { + return entity; + } + + public Deque getSamples() { + return this.samples; + } + + public void addSample(Sample sample) { + if (this.samples == null) { + this.samples = new ArrayDeque<>(); + } + if (sample != null && sample.getValueList() != null && sample.getValueList().length != 0) { + this.samples.add(sample); + } + } + + /** + * Sets a model. + * + * @param model model instance + */ + public void setModel(T model) { + this.model = model; + } + + /** + * + * @return optional model. + */ + public Optional getModel() { + return Optional.ofNullable(this.model); + } + + public void clearSamples() { + if (samples != null) { + samples.clear(); + } + } + + public void clear() { + clearSamples(); + model = null; + lastProcessedSample = null; + } + /** * Gets the Model State as a map * * @return Map of ModelStates */ + @SuppressWarnings("serial") public Map getModelStateAsMap() { return new HashMap() { { put(CommonName.MODEL_ID_FIELD, modelId); - put(ADCommonName.DETECTOR_ID_KEY, detectorId); + put(CommonName.CONFIG_ID_KEY, configId); put(MODEL_TYPE_KEY, modelType); /* A stats API broadcasts requests to all nodes and renders node responses using toXContent. 
* @@ -195,18 +257,10 @@ public Map getModelStateAsMap() { if (lastCheckpointTime != Instant.MIN) { put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime.toEpochMilli()); } - if (model != null && model instanceof EntityModel) { - EntityModel summary = (EntityModel) model; - if (summary.getEntity().isPresent()) { - put(CommonName.ENTITY_KEY, summary.getEntity().get().toStat()); - } + if (entity.isPresent()) { + put(CommonName.ENTITY_KEY, entity.get().toStat()); } } }; } - - @Override - public boolean expired(Duration stateTtl) { - return expired(lastUsedTime, stateTtl, clock.instant()); - } } diff --git a/src/main/java/org/opensearch/timeseries/ml/Sample.java b/src/main/java/org/opensearch/timeseries/ml/Sample.java new file mode 100644 index 000000000..ccc4c3dfe --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/Sample.java @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; + +public class Sample implements ToXContentObject { + private final double[] data; + private final Instant dataStartTime; + private final Instant dataEndTime; + + public Sample(double[] data, Instant dataStartTime, Instant dataEndTime) { + super(); + this.data = data; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + } + + // Invalid sample + public Sample() { + this.data = new double[0]; + this.dataStartTime = this.dataEndTime = Instant.MIN; + } + + public double[] getValueList() { + return data; + } + + public Instant getDataStartTime() { + return dataStartTime; + } + + public Instant getDataEndTime() { + return dataEndTime; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (data != null) { + xContentBuilder.array(CommonName.VALUE_LIST_FIELD, data); + } + if (dataStartTime != null) { + xContentBuilder.field(CommonName.DATA_START_TIME_FIELD, dataStartTime.toEpochMilli()); + } + if (dataEndTime != null) { + xContentBuilder.field(CommonName.DATA_END_TIME_FIELD, dataEndTime.toEpochMilli()); + } + return xContentBuilder.endObject(); + } + + /** + * Extract Sample fields out of a serialized Map, which is what we get from a get checkpoint call. + * @param map serialized sample. + * Example input map: + * Key: last_processed_sample, Value type: java.util.HashMap + * Key: data_end_time, Value type: java.lang.Long + * Value: 1695825364700, Type: java.lang.Long + * Key: data_start_time, Value type: java.lang.Long + * Value: 1695825304700, Type: java.lang.Long + * Key: value_list, Value type: java.util.ArrayList + * Item type: java.lang.Double + * Value: 8840.0, Type: java.lang.Double + * @return a Sample. 
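+ * If any required field is missing from the map, null is returned.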
+ */ + public static Sample extractSample(Map map) { + // Extract and convert values from the map + Long dataEndTimeLong = (Long) map.get(CommonName.DATA_END_TIME_FIELD); + Long dataStartTimeLong = (Long) map.get(CommonName.DATA_START_TIME_FIELD); + List valueList = (List) map.get(CommonName.VALUE_LIST_FIELD); + + // Return null if any required key is missing from the map; dereferencing a missing field below would throw + if (dataEndTimeLong == null || dataStartTimeLong == null || valueList == null) { + return null; + } + + // Convert List to double[] + double[] data = valueList.stream().mapToDouble(Double::doubleValue).toArray(); + + // Convert long to Instant + Instant dataEndTime = Instant.ofEpochMilli(dataEndTimeLong); + Instant dataStartTime = Instant.ofEpochMilli(dataStartTimeLong); + + // Create a new Sample object and return it + return new Sample(data, dataStartTime, dataEndTime); + } + + public boolean isInvalid() { + return dataStartTime.compareTo(Instant.MIN) == 0 || dataEndTime.compareTo(Instant.MIN) == 0; + } + + @Override + public String toString() { + return "Sample [data=" + Arrays.toString(data) + ", dataStartTime=" + dataStartTime + ", dataEndTime=" + dataEndTime + "]"; + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java b/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java index c33c4818f..cf045f79d 100644 --- a/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java +++ b/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java @@ -22,9 +22,10 @@ * */ public class SingleStreamModelIdMapper { - protected static final String DETECTOR_ID_PATTERN = "(.*)_model_.+"; + protected static final String CONFIG_ID_PATTERN = "(.*)_model_.+"; protected static final String RCF_MODEL_ID_PATTERN = "%s_model_rcf_%d"; protected static final String THRESHOLD_MODEL_ID_PATTERN = "%s_model_threshold"; + protected static final String CASTER_MODEL_ID_PATTERN = "%s_model_caster"; /** * Returns the model ID for the RCF model partition. @@ -48,14 +49,24 @@ public static String getThresholdModelId(String detectorId) { } /** - * Gets the detector id from the model id. + * Returns the model ID for the RCF caster model. + * + * @param forecasterId ID of the forecaster for which the model is trained + * @return ID for the forecaster model + */ + public static String getCasterModelId(String forecasterId) { + return String.format(Locale.ROOT, CASTER_MODEL_ID_PATTERN, forecasterId); + } + + /** + * Gets the config id from the model id.
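+ * For example, getConfigIdForModelId("myDetector_model_rcf_0") returns "myDetector" + * (editorial example; any model id matching CONFIG_ID_PATTERN behaves the same way).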
* * @param modelId id of a model * @return id of the detector the model is for * @throws IllegalArgumentException if model id is invalid */ - public static String getDetectorIdForModelId(String modelId) { - Matcher matcher = Pattern.compile(DETECTOR_ID_PATTERN).matcher(modelId); + public static String getConfigIdForModelId(String modelId) { + Matcher matcher = Pattern.compile(CONFIG_ID_PATTERN).matcher(modelId); if (matcher.matches()) { return matcher.group(1); } else { @@ -70,7 +81,7 @@ public static String getDetectorIdForModelId(String modelId) { * @return thresholding model Id */ public static String getThresholdModelIdFromRCFModelId(String rcfModelId) { - String detectorId = getDetectorIdForModelId(rcfModelId); + String detectorId = getConfigIdForModelId(rcfModelId); return getThresholdModelId(detectorId); } } diff --git a/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java b/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java new file mode 100644 index 000000000..960234701 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +public class TimeSeriesSingleStreamCheckpointDao { + +} diff --git a/src/main/java/org/opensearch/timeseries/model/Config.java b/src/main/java/org/opensearch/timeseries/model/Config.java index 15f67d116..fb15dcbb2 100644 --- a/src/main/java/org/opensearch/timeseries/model/Config.java +++ b/src/main/java/org/opensearch/timeseries/model/Config.java @@ -12,6 +12,7 @@ import java.time.Instant; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; @@ -40,6 +41,7 @@ import org.opensearch.timeseries.dataprocessor.PreviousValueImputer; import org.opensearch.timeseries.dataprocessor.ZeroImputer; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.owasp.encoder.Encode; import com.google.common.base.Objects; import com.google.common.collect.ImmutableList; @@ -71,6 +73,7 @@ public abstract class Config implements Writeable, ToXContentObject { public static final String USER_FIELD = "user"; public static final String RESULT_INDEX_FIELD = "result_index"; public static final String IMPUTATION_OPTION_FIELD = "imputation_option"; + public static final String TRANSFORM_DECAY_FIELD = "transform_decay"; private static final Imputer zeroImputer; private static final Imputer previousImputer; @@ -95,6 +98,7 @@ public abstract class Config implements Writeable, ToXContentObject { protected List categoryFields; protected User user; protected ImputationOption imputationOption; + protected Double transformDecay; // validation error protected String errorMessage; @@ -131,7 +135,8 @@ protected Config( User user, String resultIndex, TimeConfiguration interval, - ImputationOption imputationOption + ImputationOption imputationOption, + Double transformDecay ) { if (Strings.isBlank(name)) { errorMessage = CommonMessages.EMPTY_NAME; @@ -172,6 +177,18 @@ protected Config( return; } + if (transformDecay != null && (transformDecay <= 0 || transformDecay > 1)) { + issueType = 
ValidationIssueType.TRANSFORM_DECAY; + errorMessage = "transform decay has to be between 0 and 1"; + return; + } + + errorMessage = validateDescription(description); + if (errorMessage != null) { + issueType = ValidationIssueType.DESCRIPTION; + return; + } + this.id = id; this.version = version; this.name = name; @@ -193,6 +210,8 @@ protected Config( this.imputer = createImputer(); this.issueType = null; this.errorMessage = null; + // If transformDecay is null, use the default value from TimeSeriesSettings + this.transformDecay = Optional.ofNullable(transformDecay).orElse(TimeSeriesSettings.TIME_DECAY); } public Config(StreamInput input) throws IOException { @@ -227,6 +246,7 @@ public Config(StreamInput input) throws IOException { this.imputationOption = null; } this.imputer = createImputer(); + this.transformDecay = input.readDouble(); } /* @@ -275,6 +295,7 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } + output.writeDouble(transformDecay); } /** @@ -324,7 +345,8 @@ public boolean equals(Object o) { && Objects.equal(categoryFields, config.categoryFields) && Objects.equal(user, config.user) && Objects.equal(customResultIndex, config.customResultIndex) - && Objects.equal(imputationOption, config.imputationOption); + && Objects.equal(imputationOption, config.imputationOption) + && Objects.equal(transformDecay, config.transformDecay); } @Generated @@ -345,7 +367,8 @@ public int hashCode() { schemaVersion, user, customResultIndex, - imputationOption + imputationOption, + transformDecay ); } @@ -353,14 +376,15 @@ public int hashCode() { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder .field(NAME_FIELD, name) - .field(DESCRIPTION_FIELD, description) + .field(DESCRIPTION_FIELD, Encode.forHtml(description)) .field(TIMEFIELD_FIELD, timeField) .field(INDICES_FIELD, indices.toArray()) .field(FILTER_QUERY_FIELD, filterQuery) .field(WINDOW_DELAY_FIELD, windowDelay) .field(SHINGLE_SIZE_FIELD, shingleSize) .field(CommonName.SCHEMA_VERSION_FIELD, schemaVersion) - .field(FEATURE_ATTRIBUTES_FIELD, featureAttributes.toArray()); + .field(FEATURE_ATTRIBUTES_FIELD, featureAttributes.toArray()) + .field(TRANSFORM_DECAY_FIELD, transformDecay); if (uiMetadata != null && !uiMetadata.isEmpty()) { builder.field(UI_METADATA_FIELD, uiMetadata); @@ -505,6 +529,16 @@ public String validateCustomResultIndex(String resultIndex) { return null; } + public String validateDescription(String description) { + if (Strings.isEmpty(description)) { + return null; + } + if (description.length() > TimeSeriesSettings.MAX_DESCRIPTION_LENGTH) { + return CommonMessages.DESCRIPTION_LENGTH_TOO_LONG; + } + return null; + } + public static boolean isHC(List categoryFields) { return categoryFields != null && categoryFields.size() > 0; } @@ -521,10 +555,14 @@ public Imputer getImputer() { return imputer; } + public Double getTransformDecay() { + return transformDecay; + } + protected Imputer createImputer() { Imputer imputer = null; - // default interpolator is using last known value + // default imputer is using last known value if (imputationOption == null) { return previousImputer; } diff --git a/src/main/java/org/opensearch/timeseries/model/ConfigProfile.java b/src/main/java/org/opensearch/timeseries/model/ConfigProfile.java new file mode 100644 index 000000000..833f33765 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ConfigProfile.java @@ -0,0 +1,453 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.TaskProfile; +import org.opensearch.timeseries.constant.CommonName; + +public abstract class ConfigProfile> + implements + Writeable, + ToXContentObject, + Mergeable { + + protected ConfigState state; + protected String error; + protected ModelProfileOnNode[] modelProfile; + protected int shingleSize; + protected String coordinatingNode; + protected long totalSizeInBytes; + protected InitProgressProfile initProgress; + protected Long totalEntities; + protected Long activeEntities; + protected TaskProfileType taskProfile; + protected long modelCount; + protected String taskName; + + public ConfigProfile(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.state = in.readEnum(ConfigState.class); + } + + this.error = in.readOptionalString(); + this.modelProfile = in.readOptionalArray(ModelProfileOnNode::new, ModelProfileOnNode[]::new); + this.shingleSize = in.readOptionalInt(); + this.coordinatingNode = in.readOptionalString(); + this.totalSizeInBytes = in.readOptionalLong(); + this.totalEntities = in.readOptionalLong(); + this.activeEntities = in.readOptionalLong(); + if (in.readBoolean()) { + this.initProgress = new InitProgressProfile(in); + } + if (in.readBoolean()) { + this.taskProfile = createTaskProfile(in); + } + this.modelCount = in.readVLong(); + } + + protected ConfigProfile() { + + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + public static abstract class Builder> { + protected ConfigState state = null; + protected String error = null; + protected ModelProfileOnNode[] modelProfile = null; + protected int shingleSize = -1; + protected String coordinatingNode = null; + protected long totalSizeInBytes = -1; + protected InitProgressProfile initProgress = null; + protected Long totalEntities; + protected Long activeEntities; + protected long modelCount = 0; + + public Builder() {} + + public Builder state(ConfigState state) { + this.state = state; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder modelProfile(ModelProfileOnNode[] modelProfile) { + this.modelProfile = modelProfile; + return this; + } + + public Builder modelCount(long modelCount) { + this.modelCount = modelCount; + return this; + } + + public Builder shingleSize(int shingleSize) { + this.shingleSize = shingleSize; + return this; + } + + public Builder coordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + return this; + } + + public Builder totalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + return this; + } + + public Builder initProgress(InitProgressProfile initProgress) { + this.initProgress = initProgress; + return this; + } + + public Builder totalEntities(Long totalEntities) { + this.totalEntities = totalEntities; + 
return this; + } + + public Builder activeEntities(Long activeEntities) { + this.activeEntities = activeEntities; + return this; + } + + public abstract Builder taskProfile(TaskProfileType taskProfile); + + public abstract > ConfigProfileType build(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (state == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(state); + } + + out.writeOptionalString(error); + out.writeOptionalArray(modelProfile); + out.writeOptionalInt(shingleSize); + out.writeOptionalString(coordinatingNode); + out.writeOptionalLong(totalSizeInBytes); + out.writeOptionalLong(totalEntities); + out.writeOptionalLong(activeEntities); + if (initProgress == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + initProgress.writeTo(out); + } + if (taskProfile == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + taskProfile.writeTo(out); + } + out.writeVLong(modelCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + if (state != null) { + xContentBuilder.field(CommonName.STATE, state); + } + if (error != null) { + xContentBuilder.field(CommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + xContentBuilder.startArray(CommonName.MODELS); + for (ModelProfileOnNode profile : modelProfile) { + profile.toXContent(xContentBuilder, params); + } + xContentBuilder.endArray(); + } + if (shingleSize != -1) { + xContentBuilder.field(CommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null && !coordinatingNode.isEmpty()) { + xContentBuilder.field(CommonName.COORDINATING_NODE, coordinatingNode); + } + if (totalSizeInBytes != -1) { + xContentBuilder.field(CommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + if (initProgress != null) { + xContentBuilder.field(CommonName.INIT_PROGRESS, initProgress); + } + if (totalEntities != null) { + xContentBuilder.field(CommonName.TOTAL_ENTITIES, totalEntities); + } + if (activeEntities != null) { + xContentBuilder.field(CommonName.ACTIVE_ENTITIES, activeEntities); + } + if (taskProfile != null) { + xContentBuilder.field(getTaskFieldName(), taskProfile); + } + if (modelCount > 0) { + xContentBuilder.field(CommonName.MODEL_COUNT, modelCount); + } + return xContentBuilder.endObject(); + } + + public ConfigState getState() { + return state; + } + + public void setState(ConfigState state) { + this.state = state; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public ModelProfileOnNode[] getModelProfile() { + return modelProfile; + } + + public void setModelProfile(ModelProfileOnNode[] modelProfile) { + this.modelProfile = modelProfile; + } + + public int getShingleSize() { + return shingleSize; + } + + public void setShingleSize(int shingleSize) { + this.shingleSize = shingleSize; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public void setCoordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + } + + public long getTotalSizeInBytes() { + return totalSizeInBytes; + } + + public void setTotalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + } + + public InitProgressProfile getInitProgress() { + return initProgress; + } + + public void setInitProgress(InitProgressProfile initProgress) { + 
this.initProgress = initProgress; + } + + public Long getTotalEntities() { + return totalEntities; + } + + public void setTotalEntities(Long totalEntities) { + this.totalEntities = totalEntities; + } + + public Long getActiveEntities() { + return activeEntities; + } + + public void setActiveEntities(Long activeEntities) { + this.activeEntities = activeEntities; + } + + public TaskProfileType getTaskProfile() { + return taskProfile; + } + + public void setTaskProfile(TaskProfileType taskProfile) { + this.taskProfile = taskProfile; + } + + public long getModelCount() { + return modelCount; + } + + public void setModelCount(long modelCount) { + this.modelCount = modelCount; + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + ConfigProfile otherProfile = (ConfigProfile) other; + if (otherProfile.getState() != null) { + this.state = otherProfile.getState(); + } + if (otherProfile.getError() != null) { + this.error = otherProfile.getError(); + } + if (otherProfile.getCoordinatingNode() != null) { + this.coordinatingNode = otherProfile.getCoordinatingNode(); + } + if (otherProfile.getShingleSize() != -1) { + this.shingleSize = otherProfile.getShingleSize(); + } + if (otherProfile.getModelProfile() != null) { + this.modelProfile = otherProfile.getModelProfile(); + } + if (otherProfile.getTotalSizeInBytes() != -1) { + this.totalSizeInBytes = otherProfile.getTotalSizeInBytes(); + } + if (otherProfile.getInitProgress() != null) { + this.initProgress = otherProfile.getInitProgress(); + } + if (otherProfile.getTotalEntities() != null) { + this.totalEntities = otherProfile.getTotalEntities(); + } + if (otherProfile.getActiveEntities() != null) { + this.activeEntities = otherProfile.getActiveEntities(); + } + if (otherProfile.getTaskProfile() != null) { + this.taskProfile = otherProfile.getTaskProfile(); + } + if (otherProfile.getModelCount() > 0) { + this.modelCount = otherProfile.getModelCount(); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof ConfigProfile) { + ConfigProfile other = (ConfigProfile) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + if (state != null) { + equalsBuilder.append(state, other.state); + } + if (error != null) { + equalsBuilder.append(error, other.error); + } + if (modelProfile != null && modelProfile.length > 0) { + equalsBuilder.append(modelProfile, other.modelProfile); + } + if (shingleSize != -1) { + equalsBuilder.append(shingleSize, other.shingleSize); + } + if (coordinatingNode != null) { + equalsBuilder.append(coordinatingNode, other.coordinatingNode); + } + if (totalSizeInBytes != -1) { + equalsBuilder.append(totalSizeInBytes, other.totalSizeInBytes); + } + if (initProgress != null) { + equalsBuilder.append(initProgress, other.initProgress); + } + if (totalEntities != null) { + equalsBuilder.append(totalEntities, other.totalEntities); + } + if (activeEntities != null) { + equalsBuilder.append(activeEntities, other.activeEntities); + } + if (taskProfile != null) { + equalsBuilder.append(taskProfile, other.taskProfile); + } + if (modelCount > 0) { + equalsBuilder.append(modelCount, other.modelCount); + } + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder() + .append(state) + .append(error) + .append(modelProfile) + .append(shingleSize) + 
.append(coordinatingNode) + .append(totalSizeInBytes) + .append(initProgress) + .append(totalEntities) + .append(activeEntities) + .append(taskProfile) + .append(modelCount) + .toHashCode(); + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (state != null) { + toStringBuilder.append(CommonName.STATE, state); + } + if (error != null) { + toStringBuilder.append(CommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + toStringBuilder.append(modelProfile); + } + if (shingleSize != -1) { + toStringBuilder.append(CommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null) { + toStringBuilder.append(CommonName.COORDINATING_NODE, coordinatingNode); + } + if (totalSizeInBytes != -1) { + toStringBuilder.append(CommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + if (initProgress != null) { + toStringBuilder.append(CommonName.INIT_PROGRESS, initProgress); + } + if (totalEntities != null) { + toStringBuilder.append(CommonName.TOTAL_ENTITIES, totalEntities); + } + if (activeEntities != null) { + toStringBuilder.append(CommonName.ACTIVE_ENTITIES, activeEntities); + } + if (taskProfile != null) { + toStringBuilder.append(getTaskFieldName(), taskProfile); + } + if (modelCount > 0) { + toStringBuilder.append(CommonName.MODEL_COUNT, modelCount); + } + return toStringBuilder.toString(); + } + + protected abstract TaskProfileType createTaskProfile(StreamInput in) throws IOException; + + protected abstract String getTaskFieldName(); +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorState.java b/src/main/java/org/opensearch/timeseries/model/ConfigState.java similarity index 83% rename from src/main/java/org/opensearch/ad/model/DetectorState.java rename to src/main/java/org/opensearch/timeseries/model/ConfigState.java index a4959417b..4af52f2ee 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorState.java +++ b/src/main/java/org/opensearch/timeseries/model/ConfigState.java @@ -9,9 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; -public enum DetectorState { +public enum ConfigState { DISABLED, INIT, RUNNING diff --git a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java b/src/main/java/org/opensearch/timeseries/model/ConfigValidationIssue.java similarity index 87% rename from src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java rename to src/main/java/org/opensearch/timeseries/model/ConfigValidationIssue.java index 48586e7f8..895c070ec 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java +++ b/src/main/java/org/opensearch/timeseries/model/ConfigValidationIssue.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; import java.util.Map; @@ -19,14 +19,11 @@ import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.ValidationAspect; -import org.opensearch.timeseries.model.ValidationIssueType; import com.google.common.base.Objects; /** - * DetectorValidationIssue is a single validation issue found for detector. + * ConfigValidationIssue is a single validation issue found for config. 
* * For example, if detector's multiple features are using wrong type field or non existing field * the issue would be in `detector` aspect, not `model`; @@ -35,7 +32,7 @@ * subIssues are issues for each feature; * suggestion is how to fix the issue/subIssues found */ -public class DetectorValidationIssue implements ToXContentObject, Writeable { +public class ConfigValidationIssue implements ToXContentObject, Writeable { private static final String MESSAGE_FIELD = "message"; private static final String SUGGESTED_FIELD_NAME = "suggested_value"; private static final String SUB_ISSUES_FIELD_NAME = "sub_issues"; @@ -66,7 +63,7 @@ public IntervalTimeConfiguration getIntervalSuggestion() { return intervalSuggestion; } - public DetectorValidationIssue( + public ConfigValidationIssue( ValidationAspect aspect, ValidationIssueType type, String message, @@ -80,11 +77,11 @@ public DetectorValidationIssue( this.intervalSuggestion = intervalSuggestion; } - public DetectorValidationIssue(ValidationAspect aspect, ValidationIssueType type, String message) { + public ConfigValidationIssue(ValidationAspect aspect, ValidationIssueType type, String message) { this(aspect, type, message, null, null); } - public DetectorValidationIssue(StreamInput input) throws IOException { + public ConfigValidationIssue(StreamInput input) throws IOException { aspect = input.readEnum(ValidationAspect.class); type = input.readEnum(ValidationIssueType.class); message = input.readString(); @@ -139,7 +136,7 @@ public boolean equals(Object o) { return true; if (o == null || getClass() != o.getClass()) return false; - DetectorValidationIssue anotherIssue = (DetectorValidationIssue) o; + ConfigValidationIssue anotherIssue = (ConfigValidationIssue) o; return Objects.equal(getAspect(), anotherIssue.getAspect()) && Objects.equal(getMessage(), anotherIssue.getMessage()) && Objects.equal(getSubIssues(), anotherIssue.getSubIssues()) diff --git a/src/main/java/org/opensearch/timeseries/model/Entity.java b/src/main/java/org/opensearch/timeseries/model/Entity.java index f05f5dc2a..8fc6f77c9 100644 --- a/src/main/java/org/opensearch/timeseries/model/Entity.java +++ b/src/main/java/org/opensearch/timeseries/model/Entity.java @@ -24,6 +24,7 @@ import java.util.SortedMap; import java.util.TreeMap; +import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.util.SetOnce; import org.opensearch.common.Numbers; import org.opensearch.common.hash.MurmurHash3; @@ -38,6 +39,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParser.Token; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.timeseries.annotation.Generated; import org.opensearch.timeseries.constant.CommonName; @@ -339,9 +343,11 @@ public Map getAttributes() { } } * + * Used to query customer index + * *@return a list of term query builder */ - public List getTermQueryBuilders() { + public List getTermQueryForCustomerIndex() { List res = new ArrayList<>(); for (Map.Entry attribute : attributes.entrySet()) { res.add(new TermQueryBuilder(attribute.getKey(), attribute.getValue())); @@ -349,7 +355,7 @@ public List getTermQueryBuilders() { return res; } - public List getTermQueryBuilders(String pathPrefix) { + public List getTermQueryForCustomerIndex(String pathPrefix) { List res = new ArrayList<>(); for 
(Map.Entry attribute : attributes.entrySet()) { res.add(new TermQueryBuilder(pathPrefix + attribute.getKey(), attribute.getValue())); @@ -357,6 +363,62 @@ public List getTermQueryBuilders(String pathPrefix) { return res; } + /** + * Used to query result index. + * + * @return a list of term queries to locate documents containing the entity + */ + public List getTermQueryForResultIndex() { + String path = "entity"; + String entityName = path + ".name"; + String entityValue = path + ".value"; + + List res = new ArrayList<>(); + + for (Map.Entry attribute : attributes.entrySet()) { + /* + * each attribute pair corresponds to a nested query like + "nested": { + "query": { + "bool": { + "filter": [ + { + "term": { + "entity.name": { + "value": "turkey4", + "boost": 1 + } + } + }, + { + "term": { + "entity.value": { + "value": "Turkey", + "boost": 1 + } + } + } + ] + } + }, + "path": "entity", + "ignore_unmapped": false, + "score_mode": "none", + "boost": 1 + } + },*/ + BoolQueryBuilder nestedBoolQueryBuilder = new BoolQueryBuilder(); + + TermQueryBuilder entityNameFilterQuery = QueryBuilders.termQuery(entityName, attribute.getKey()); + nestedBoolQueryBuilder.filter(entityNameFilterQuery); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValue, attribute.getValue()); + nestedBoolQueryBuilder.filter(entityValueFilterQuery); + + res.add(new NestedQueryBuilder(path, nestedBoolQueryBuilder, ScoreMode.None)); + } + return res; + } + /** * From json to Entity instance * @param entityValue json array consisting attributes diff --git a/src/main/java/org/opensearch/ad/model/EntityProfile.java b/src/main/java/org/opensearch/timeseries/model/EntityProfile.java similarity index 95% rename from src/main/java/org/opensearch/ad/model/EntityProfile.java rename to src/main/java/org/opensearch/timeseries/model/EntityProfile.java index 4f2306e96..9eee3c8a9 100644 --- a/src/main/java/org/opensearch/ad/model/EntityProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; import java.util.Optional; @@ -17,12 +17,12 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; /** * Profile output for detector entity. 
@@ -168,13 +168,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(LAST_SAMPLE_TIMESTAMP, lastSampleTimestampMs); } if (initProgress != null) { - builder.field(ADCommonName.INIT_PROGRESS, initProgress); + builder.field(CommonName.INIT_PROGRESS, initProgress); } if (modelProfile != null) { - builder.field(ADCommonName.MODEL, modelProfile); + builder.field(CommonName.MODEL, modelProfile); } if (state != null && state != EntityState.UNKNOWN) { - builder.field(ADCommonName.STATE, state); + builder.field(CommonName.STATE, state); } builder.endObject(); return builder; @@ -213,13 +213,13 @@ public String toString() { builder.append(LAST_SAMPLE_TIMESTAMP, lastSampleTimestampMs); } if (initProgress != null) { - builder.append(ADCommonName.INIT_PROGRESS, initProgress); + builder.append(CommonName.INIT_PROGRESS, initProgress); } if (modelProfile != null) { - builder.append(ADCommonName.MODELS, modelProfile); + builder.append(CommonName.MODELS, modelProfile); } if (state != null && state != EntityState.UNKNOWN) { - builder.append(ADCommonName.STATE, state); + builder.append(CommonName.STATE, state); } return builder.toString(); } diff --git a/src/main/java/org/opensearch/ad/model/EntityProfileName.java b/src/main/java/org/opensearch/timeseries/model/EntityProfileName.java similarity index 75% rename from src/main/java/org/opensearch/ad/model/EntityProfileName.java rename to src/main/java/org/opensearch/timeseries/model/EntityProfileName.java index 84fd92987..c32636d5f 100644 --- a/src/main/java/org/opensearch/ad/model/EntityProfileName.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityProfileName.java @@ -9,20 +9,20 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.util.Collection; import java.util.Set; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonName; public enum EntityProfileName implements Name { - INIT_PROGRESS(ADCommonName.INIT_PROGRESS), - ENTITY_INFO(ADCommonName.ENTITY_INFO), - STATE(ADCommonName.STATE), - MODELS(ADCommonName.MODELS); + INIT_PROGRESS(CommonName.INIT_PROGRESS), + ENTITY_INFO(CommonName.ENTITY_INFO), + STATE(CommonName.STATE), + MODELS(CommonName.MODELS); private String name; @@ -42,13 +42,13 @@ public String getName() { public static EntityProfileName getName(String name) { switch (name) { - case ADCommonName.INIT_PROGRESS: + case CommonName.INIT_PROGRESS: return INIT_PROGRESS; - case ADCommonName.ENTITY_INFO: + case CommonName.ENTITY_INFO: return ENTITY_INFO; - case ADCommonName.STATE: + case CommonName.STATE: return STATE; - case ADCommonName.MODELS: + case CommonName.MODELS: return MODELS; default: throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); diff --git a/src/main/java/org/opensearch/ad/model/EntityState.java b/src/main/java/org/opensearch/timeseries/model/EntityState.java similarity index 89% rename from src/main/java/org/opensearch/ad/model/EntityState.java rename to src/main/java/org/opensearch/timeseries/model/EntityState.java index 1e0d05d8e..36ab0fc0e 100644 --- a/src/main/java/org/opensearch/ad/model/EntityState.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityState.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; public enum EntityState { UNKNOWN, diff --git a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java b/src/main/java/org/opensearch/timeseries/model/EntityTaskProfile.java similarity index 90% rename from src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java rename to src/main/java/org/opensearch/timeseries/model/EntityTaskProfile.java index 3d473d0e2..5971af22f 100644 --- a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/EntityTaskProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -22,12 +22,11 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.timeseries.model.Entity; /** - * HC detector's entity task profile. + * HC analysis's entity task profile. */ -public class ADEntityTaskProfile implements ToXContentObject, Writeable { +public class EntityTaskProfile implements ToXContentObject, Writeable { public static final String SHINGLE_SIZE_FIELD = "shingle_size"; public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; @@ -37,7 +36,7 @@ public class ADEntityTaskProfile implements ToXContentObject, Writeable { public static final String NODE_ID_FIELD = "node_id"; public static final String ENTITY_FIELD = "entity"; public static final String TASK_ID_FIELD = "task_id"; - public static final String AD_TASK_TYPE_FIELD = "task_type"; + public static final String TASK_TYPE_FIELD = "task_type"; private Integer shingleSize; private Long rcfTotalUpdates; @@ -47,9 +46,9 @@ public class ADEntityTaskProfile implements ToXContentObject, Writeable { private String nodeId; private Entity entity; private String taskId; - private String adTaskType; + private String taskType; - public ADEntityTaskProfile( + public EntityTaskProfile( Integer shingleSize, Long rcfTotalUpdates, Boolean thresholdModelTrained, @@ -68,10 +67,10 @@ public ADEntityTaskProfile( this.nodeId = nodeId; this.entity = entity; this.taskId = taskId; - this.adTaskType = adTaskType; + this.taskType = adTaskType; } - public static ADEntityTaskProfile parse(XContentParser parser) throws IOException { + public static EntityTaskProfile parse(XContentParser parser) throws IOException { Integer shingleSize = null; Long rcfTotalUpdates = null; Boolean thresholdModelTrained = null; @@ -112,7 +111,7 @@ public static ADEntityTaskProfile parse(XContentParser parser) throws IOExceptio case TASK_ID_FIELD: taskId = parser.text(); break; - case AD_TASK_TYPE_FIELD: + case TASK_TYPE_FIELD: taskType = parser.text(); break; default: @@ -120,7 +119,7 @@ public static ADEntityTaskProfile parse(XContentParser parser) throws IOExceptio break; } } - return new ADEntityTaskProfile( + return new EntityTaskProfile( shingleSize, rcfTotalUpdates, thresholdModelTrained, @@ -133,7 +132,7 @@ public static ADEntityTaskProfile parse(XContentParser parser) throws IOExceptio ); } - public ADEntityTaskProfile(StreamInput input) throws IOException { + public EntityTaskProfile(StreamInput input) throws IOException { this.shingleSize = input.readOptionalInt(); this.rcfTotalUpdates = input.readOptionalLong(); this.thresholdModelTrained = input.readOptionalBoolean(); @@ -146,7 +145,7 @@ public 
ADEntityTaskProfile(StreamInput input) throws IOException { this.entity = null; } this.taskId = input.readOptionalString(); - this.adTaskType = input.readOptionalString(); + this.taskType = input.readOptionalString(); } @Override @@ -164,7 +163,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalString(taskId); - out.writeOptionalString(adTaskType); + out.writeOptionalString(taskType); } @Override @@ -194,8 +193,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (taskId != null) { xContentBuilder.field(TASK_ID_FIELD, taskId); } - if (adTaskType != null) { - xContentBuilder.field(AD_TASK_TYPE_FIELD, adTaskType); + if (taskType != null) { + xContentBuilder.field(TASK_TYPE_FIELD, taskType); } return xContentBuilder.endObject(); } @@ -265,11 +264,11 @@ public void setTaskId(String taskId) { } public String getAdTaskType() { - return adTaskType; + return taskType; } public void setAdTaskType(String adTaskType) { - this.adTaskType = adTaskType; + this.taskType = adTaskType; } @Override @@ -278,7 +277,7 @@ public boolean equals(Object o) { return true; if (o == null || getClass() != o.getClass()) return false; - ADEntityTaskProfile that = (ADEntityTaskProfile) o; + EntityTaskProfile that = (EntityTaskProfile) o; return Objects.equals(shingleSize, that.shingleSize) && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) && Objects.equals(thresholdModelTrained, that.thresholdModelTrained) @@ -286,7 +285,7 @@ public boolean equals(Object o) { && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) && Objects.equals(nodeId, that.nodeId) && Objects.equals(taskId, that.taskId) - && Objects.equals(adTaskType, that.adTaskType) + && Objects.equals(taskType, that.taskType) && Objects.equals(entity, that.entity); } @@ -302,7 +301,7 @@ public int hashCode() { nodeId, entity, taskId, - adTaskType + taskType ); } } diff --git a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java index 7ccc58b59..0393122bd 100644 --- a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java +++ b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java @@ -39,17 +39,6 @@ public abstract class IndexableResult implements Writeable, ToXContentObject { protected final Optional optionalEntity; protected User user; protected final Integer schemaVersion; - /* - * model id for easy aggregations of entities. The front end needs to query - * for entities ordered by the descending/ascending order of feature values. - * After supporting multi-category fields, it is hard to write such queries - * since the entity information is stored in a nested object array. - * Also, the front end has all code/queries/ helper functions in place to - * rely on a single key per entity combo. Adding model id to forecast result - * to help the transition to multi-categorical field less painful. 
- */ - protected final String modelId; - protected final String entityId; protected final String taskId; public IndexableResult( @@ -63,7 +52,6 @@ public IndexableResult( Optional entity, User user, Integer schemaVersion, - String modelId, String taskId ) { this.configId = configId; @@ -76,9 +64,7 @@ public IndexableResult( this.optionalEntity = entity; this.user = user; this.schemaVersion = schemaVersion; - this.modelId = modelId; this.taskId = taskId; - this.entityId = getEntityId(entity, configId); } public IndexableResult(StreamInput input) throws IOException { @@ -104,9 +90,7 @@ public IndexableResult(StreamInput input) throws IOException { user = null; } this.schemaVersion = input.readInt(); - this.modelId = input.readOptionalString(); this.taskId = input.readOptionalString(); - this.entityId = input.readOptionalString(); } @Override @@ -134,9 +118,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); // user does not exist } out.writeInt(schemaVersion); - out.writeOptionalString(modelId); out.writeOptionalString(taskId); - out.writeOptionalString(entityId); } public String getConfigId() { @@ -171,18 +153,10 @@ public Optional getEntity() { return optionalEntity; } - public String getModelId() { - return modelId; - } - public String getTaskId() { return taskId; } - public String getEntityId() { - return entityId; - } - /** * entityId equals to model Id. It is hard to explain to users what * modelId is. entityId is more user friendly. @@ -209,9 +183,7 @@ public boolean equals(Object o) { && Objects.equal(executionStartTime, that.executionStartTime) && Objects.equal(executionEndTime, that.executionEndTime) && Objects.equal(error, that.error) - && Objects.equal(optionalEntity, that.optionalEntity) - && Objects.equal(modelId, that.modelId) - && Objects.equal(entityId, that.entityId); + && Objects.equal(optionalEntity, that.optionalEntity); } @Generated @@ -227,9 +199,7 @@ public int hashCode() { executionStartTime, executionEndTime, error, - optionalEntity, - modelId, - entityId + optionalEntity ); } @@ -245,8 +215,6 @@ public String toString() { .append("executionEndTime", executionEndTime) .append("error", error) .append("entity", optionalEntity) - .append("modelId", modelId) - .append("entityId", entityId) .toString(); } diff --git a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java b/src/main/java/org/opensearch/timeseries/model/InitProgressProfile.java similarity index 99% rename from src/main/java/org/opensearch/ad/model/InitProgressProfile.java rename to src/main/java/org/opensearch/timeseries/model/InitProgressProfile.java index 4147f8ef4..1b2a83f4c 100644 --- a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/InitProgressProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. 
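Since modelId and entityId drop out of IndexableResult's wire format above, writeTo and the StreamInput constructor must stay exact mirrors, or mixed-version nodes will mis-read the stream. A minimal round-trip check of the kind commonly used for Writeable classes (hypothetical test helper; assumes BytesStreamOutput and the core stream imports):

static &lt;T extends IndexableResult&gt; T roundTrip(T result, Writeable.Reader&lt;T&gt; reader) throws IOException {
    // Serialize, then read back with the matching reader; the field order in
    // writeTo must agree with the read order in the StreamInput constructor.
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        result.writeTo(out);
        try (StreamInput in = out.bytes().streamInput()) {
            return reader.read(in);
        }
    }
}
// e.g. assertEquals(original, roundTrip(original, AnomalyResult::new));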
*/ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; diff --git a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java index eaa6301df..22c0fb416 100644 --- a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java +++ b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java @@ -103,6 +103,12 @@ public int hashCode() { return Objects.hashCode(interval, unit); } + @Generated + @Override + public String toString() { + return "IntervalTimeConfiguration [interval=" + interval + ", unit=" + unit + "]"; + } + public long getInterval() { return interval; } @@ -119,4 +125,13 @@ public ChronoUnit getUnit() { public Duration toDuration() { return Duration.of(interval, unit); } + + /** + * + * @param other interval to compare + * @return current interval is larger than or equal to the given interval + */ + public boolean gte(IntervalTimeConfiguration other) { + return toDuration().compareTo(other.toDuration()) >= 0; + } } diff --git a/src/main/java/org/opensearch/timeseries/model/Job.java b/src/main/java/org/opensearch/timeseries/model/Job.java index 958152e2c..d258279e7 100644 --- a/src/main/java/org/opensearch/timeseries/model/Job.java +++ b/src/main/java/org/opensearch/timeseries/model/Job.java @@ -12,13 +12,13 @@ package org.opensearch.timeseries.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.DEFAULT_JOB_LOC_DURATION_SECONDS; import java.io.IOException; import java.time.Instant; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -31,6 +31,8 @@ import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ParseUtils; import com.google.common.base.Objects; @@ -44,7 +46,7 @@ enum ScheduleType { INTERVAL } - public static final String PARSE_FIELD_NAME = "AnomalyDetectorJob"; + public static final String PARSE_FIELD_NAME = "TimeSeriesJob"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( Job.class, new ParseField(PARSE_FIELD_NAME), @@ -62,7 +64,9 @@ enum ScheduleType { public static final String DISABLED_TIME_FIELD = "disabled_time"; public static final String USER_FIELD = "user"; private static final String RESULT_INDEX_FIELD = "result_index"; + private static final String TYPE_FIELD = "type"; + // name is config id private final String name; private final Schedule schedule; private final TimeConfiguration windowDelay; @@ -73,6 +77,7 @@ enum ScheduleType { private final Long lockDurationSeconds; private final User user; private String resultIndex; + private AnalysisType analysisType; public Job( String name, @@ -84,7 +89,8 @@ public Job( Instant lastUpdateTime, Long lockDurationSeconds, User user, - String resultIndex + String resultIndex, + AnalysisType type ) { this.name = name; this.schedule = schedule; @@ -96,6 +102,7 @@ 
public Job( this.lockDurationSeconds = lockDurationSeconds; this.user = user; this.resultIndex = resultIndex; + this.analysisType = type; } public Job(StreamInput input) throws IOException { @@ -117,6 +124,7 @@ public Job(StreamInput input) throws IOException { user = null; } resultIndex = input.readOptionalString(); + this.analysisType = input.readEnum(AnalysisType.class); } @Override @@ -129,7 +138,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(IS_ENABLED_FIELD, isEnabled) .field(ENABLED_TIME_FIELD, enabledTime.toEpochMilli()) .field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()) - .field(LOCK_DURATION_SECONDS, lockDurationSeconds); + .field(LOCK_DURATION_SECONDS, lockDurationSeconds) + .field(TYPE_FIELD, analysisType); if (disabledTime != null) { xContentBuilder.field(DISABLED_TIME_FIELD, disabledTime.toEpochMilli()); } @@ -164,6 +174,7 @@ public void writeTo(StreamOutput output) throws IOException { output.writeBoolean(false); // user does not exist } output.writeOptionalString(resultIndex); + output.writeEnum(analysisType); } public static Job parse(XContentParser parser) throws IOException { @@ -175,9 +186,10 @@ public static Job parse(XContentParser parser) throws IOException { Instant enabledTime = null; Instant disabledTime = null; Instant lastUpdateTime = null; - Long lockDurationSeconds = DEFAULT_JOB_LOC_DURATION_SECONDS; + Long lockDurationSeconds = TimeSeriesSettings.DEFAULT_JOB_LOC_DURATION_SECONDS; User user = null; String resultIndex = null; + String analysisType = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -215,6 +227,9 @@ public static Job parse(XContentParser parser) throws IOException { case RESULT_INDEX_FIELD: resultIndex = parser.text(); break; + case TYPE_FIELD: + analysisType = parser.text(); + break; default: parser.skipChildren(); break; @@ -230,7 +245,10 @@ public static Job parse(XContentParser parser) throws IOException { lastUpdateTime, lockDurationSeconds, user, - resultIndex + resultIndex, + (Strings.isEmpty(analysisType) || AnalysisType.AD == AnalysisType.valueOf(analysisType)) + ?
AnalysisType.AD + : AnalysisType.FORECAST ); } @@ -248,12 +266,13 @@ public boolean equals(Object o) { && Objects.equal(getDisabledTime(), that.getDisabledTime()) && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) && Objects.equal(getLockDurationSeconds(), that.getLockDurationSeconds()) - && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()); + && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()) + && Objects.equal(getAnalysisType(), that.getAnalysisType()); } @Override public int hashCode() { - return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime); + return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime, analysisType); } @Override @@ -301,4 +320,8 @@ public User getUser() { public String getCustomResultIndex() { return resultIndex; } + + public AnalysisType getAnalysisType() { + return analysisType; + } } diff --git a/src/main/java/org/opensearch/ad/model/Mergeable.java b/src/main/java/org/opensearch/timeseries/model/Mergeable.java similarity index 89% rename from src/main/java/org/opensearch/ad/model/Mergeable.java rename to src/main/java/org/opensearch/timeseries/model/Mergeable.java index 980dad1a4..bdb9ef49e 100644 --- a/src/main/java/org/opensearch/ad/model/Mergeable.java +++ b/src/main/java/org/opensearch/timeseries/model/Mergeable.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; public interface Mergeable { void merge(Mergeable other); diff --git a/src/main/java/org/opensearch/timeseries/model/MergeableList.java b/src/main/java/org/opensearch/timeseries/model/MergeableList.java index fd9f26e84..188c0fa44 100644 --- a/src/main/java/org/opensearch/timeseries/model/MergeableList.java +++ b/src/main/java/org/opensearch/timeseries/model/MergeableList.java @@ -13,8 +13,6 @@ import java.util.List; -import org.opensearch.ad.model.Mergeable; - public class MergeableList implements Mergeable { private final List elements; diff --git a/src/main/java/org/opensearch/ad/model/ModelProfile.java b/src/main/java/org/opensearch/timeseries/model/ModelProfile.java similarity index 97% rename from src/main/java/org/opensearch/ad/model/ModelProfile.java rename to src/main/java/org/opensearch/timeseries/model/ModelProfile.java index 1d6d0ce85..63fdbcd02 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/ModelProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; @@ -22,7 +22,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; /** * Used to show model information in profile API diff --git a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java b/src/main/java/org/opensearch/timeseries/model/ModelProfileOnNode.java similarity index 95% rename from src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java rename to src/main/java/org/opensearch/timeseries/model/ModelProfileOnNode.java index 1e45bcc7a..ed61342c8 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java +++ b/src/main/java/org/opensearch/timeseries/model/ModelProfileOnNode.java @@ -9,19 +9,19 @@ * GitHub history for details. 
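The fallback in Job.parse above keeps pre-existing job documents, which lack a "type" field, working as AD jobs. The default rule, isolated for readability (hypothetical helper, not in the change):

static AnalysisType analysisTypeOrDefault(String parsed) {
    // Missing or explicit "AD" maps to AD, preserving backward compatibility.
    // Note AnalysisType.valueOf throws for unrecognized strings, so malformed
    // documents fail fast rather than silently becoming forecasters.
    return (Strings.isEmpty(parsed) || AnalysisType.AD == AnalysisType.valueOf(parsed))
        ? AnalysisType.AD
        : AnalysisType.FORECAST;
}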
*/ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; public class ModelProfileOnNode implements Writeable, ToXContent { // field name in toXContent @@ -98,7 +98,7 @@ public int hashCode() { @Override public String toString() { ToStringBuilder builder = new ToStringBuilder(this); - builder.append(ADCommonName.MODEL, modelProfile); + builder.append(CommonName.MODEL, modelProfile); builder.append(NODE_ID, nodeId); return builder.toString(); } diff --git a/src/main/java/org/opensearch/timeseries/model/ProfileName.java b/src/main/java/org/opensearch/timeseries/model/ProfileName.java new file mode 100644 index 000000000..dd3f1bac9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ProfileName.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.model; + +import java.util.Collection; +import java.util.Set; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonName; + +public enum ProfileName implements Name { + STATE(CommonName.STATE), + ERROR(CommonName.ERROR), + COORDINATING_NODE(CommonName.COORDINATING_NODE), + SHINGLE_SIZE(CommonName.SHINGLE_SIZE), + TOTAL_SIZE_IN_BYTES(CommonName.TOTAL_SIZE_IN_BYTES), + MODELS(CommonName.MODELS), + INIT_PROGRESS(CommonName.INIT_PROGRESS), + TOTAL_ENTITIES(CommonName.TOTAL_ENTITIES), + ACTIVE_ENTITIES(CommonName.ACTIVE_ENTITIES), + // AD only + AD_TASK(ADCommonName.AD_TASK), + // Forecast only + FORECAST_TASK(ForecastCommonName.FORECAST_TASK); + + private String name; + + ProfileName(String name) { + this.name = name; + } + + /** + * Get profile name + * + * @return name + */ + @Override + public String getName() { + return name; + } + + public static ProfileName getName(String name) { + switch (name) { + case CommonName.STATE: + return STATE; + case CommonName.ERROR: + return ERROR; + case CommonName.COORDINATING_NODE: + return COORDINATING_NODE; + case CommonName.SHINGLE_SIZE: + return SHINGLE_SIZE; + case CommonName.TOTAL_SIZE_IN_BYTES: + return TOTAL_SIZE_IN_BYTES; + case CommonName.MODELS: + return MODELS; + case CommonName.INIT_PROGRESS: + return INIT_PROGRESS; + case CommonName.TOTAL_ENTITIES: + return TOTAL_ENTITIES; + case CommonName.ACTIVE_ENTITIES: + return ACTIVE_ENTITIES; + case ADCommonName.AD_TASK: + return AD_TASK; + case ForecastCommonName.FORECAST_TASK: + return FORECAST_TASK; + default: + throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); + } + } + + public static Set getNames(Collection names) { + return 
Name.getNameFromCollection(names, ProfileName::getName); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/TaskState.java b/src/main/java/org/opensearch/timeseries/model/TaskState.java index 2b5c4240e..6f845d49a 100644 --- a/src/main/java/org/opensearch/timeseries/model/TaskState.java +++ b/src/main/java/org/opensearch/timeseries/model/TaskState.java @@ -50,13 +50,32 @@ * */ public enum TaskState { - CREATED, - INIT, - RUNNING, - FAILED, - STOPPED, - FINISHED; + // AD task state + CREATED("Created"), + INIT("Init"), + RUNNING("Running"), + FAILED("Failed"), + STOPPED("Stopped"), + FINISHED("Finished"), + + // Forecast task state + INIT_TEST("Initializing test"), + TEST_COMPLETE("Test complete"), + INIT_TEST_FAILED("Initializing test failed"), + INACTIVE("Inactive"); + + private final String description; + + // Constructor + TaskState(String description) { + this.description = description; + } + + // Getter + public String getDescription() { + return description; + } public static List NOT_ENDED_STATES = ImmutableList - .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name()); + .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name(), INIT_TEST.name()); } diff --git a/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java index fd57de7cd..0384f1356 100644 --- a/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java +++ b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java @@ -41,6 +41,8 @@ public abstract class TimeSeriesTask implements ToXContentObject, Writeable { public static final String ESTIMATED_MINUTES_LEFT_FIELD = "estimated_minutes_left"; public static final String USER_FIELD = "user"; public static final String HISTORICAL_TASK_PREFIX = "HISTORICAL"; + public static final String RUN_ONCE_TASK_PREFIX = "RUN_ONCE"; + public static final String REAL_TIME_TASK_PREFIX = "REALTIME"; protected String configId = null; protected String taskId = null; @@ -200,6 +202,14 @@ public boolean isHistoricalTask() { return taskType.startsWith(TimeSeriesTask.HISTORICAL_TASK_PREFIX); } + public boolean isRunOnceTask() { + return taskType.startsWith(TimeSeriesTask.RUN_ONCE_TASK_PREFIX); + } + + public boolean isRealTimeTask() { + return taskType.startsWith(TimeSeriesTask.REAL_TIME_TASK_PREFIX); + } + /** * Get config level task id. If a task has no parent task, the task is config level task. * @return config level task id @@ -440,7 +450,7 @@ public int hashCode() { ); } - public abstract boolean isEntityTask(); + public abstract boolean isHistoricalEntityTask(); public String getEntityModelId() { return entity == null ? 
null : entity.getModelId(configId).orElse(null); diff --git a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java index 01913a9c6..717a4f42d 100644 --- a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java +++ b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java @@ -13,6 +13,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.transport.SearchTopForecastResultRequest; import org.opensearch.timeseries.Name; public enum ValidationIssueType implements Name { @@ -32,7 +33,10 @@ public enum ValidationIssueType implements Name { IMPUTATION(Config.IMPUTATION_OPTION_FIELD), DETECTION_INTERVAL(AnomalyDetector.DETECTION_INTERVAL_FIELD), FORECAST_INTERVAL(Forecaster.FORECAST_INTERVAL_FIELD), - HORIZON_SIZE(Forecaster.HORIZON_FIELD); + HORIZON_SIZE(Forecaster.HORIZON_FIELD), + SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD), + TRANSFORM_DECAY(Config.TRANSFORM_DECAY_FIELD), + DESCRIPTION(Config.DESCRIPTION_FIELD); private String name; diff --git a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java similarity index 91% rename from src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java index 7ba8b4383..41f62e243 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -24,8 +24,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; /** @@ -46,8 +46,9 @@ public BatchWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -58,7 +59,8 @@ public BatchWorker( Duration executionTtl, Setting batchSizeSetting, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager timeSeriesNodeStateManager, + AnalysisType context ) { super( queueName, @@ -67,8 +69,9 @@ public BatchWorker( maxHeapPercentForQueueSetting, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -78,7 +81,8 @@ public BatchWorker( concurrencySetting, executionTtl, stateTtl, - nodeStateManager + timeSeriesNodeStateManager, + context ); this.batchSize = batchSizeSetting.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(batchSizeSetting, it -> batchSize = it); @@ -111,7 +115,7 @@ protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallbac ThreadedActionListener listener = new ThreadedActionListener<>( LOG, threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + threadPoolName, 
getResponseListener(toProcess, batchRequest), false ); diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java similarity index 71% rename from src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java index 91382a4b5..788d2e370 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -21,35 +21,42 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.util.DateUtils; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.util.DateUtils; -public class CheckPointMaintainRequestAdapter { +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Convert from ModelRequest to CheckpointWriteRequest + * + */ +public class CheckPointMaintainRequestAdapter & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CacheType extends TimeSeriesCache> { private static final Logger LOG = LogManager.getLogger(CheckPointMaintainRequestAdapter.class); - private CacheProvider cache; - private CheckpointDao checkpointDao; + private CheckpointDaoType checkpointDao; private String indexName; private Duration checkpointInterval; private Clock clock; + private Provider cache; public CheckPointMaintainRequestAdapter( - CacheProvider cache, - CheckpointDao checkpointDao, + CheckpointDaoType checkpointDao, String indexName, Setting checkpointIntervalSetting, Clock clock, ClusterService clusterService, - Settings settings + Settings settings, + Provider cache ) { - this.cache = cache; this.checkpointDao = checkpointDao; this.indexName = indexName; @@ -59,15 +66,16 @@ public CheckPointMaintainRequestAdapter( .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); this.clock = clock; + this.cache = cache; } public Optional convert(CheckpointMaintainRequest request) { - String detectorId = request.getId(); - String modelId = request.getEntityModelId(); + String configId = request.getConfigId(); + String modelId = request.getModelId(); - Optional> stateToMaintain = cache.get().getForMaintainance(detectorId, modelId); - if (!stateToMaintain.isEmpty()) { - ModelState state = stateToMaintain.get(); + Optional> 
stateToMaintain = cache.get().getForMaintainance(configId, modelId); + if (stateToMaintain.isPresent()) { + ModelState state = stateToMaintain.get(); Instant instant = state.getLastCheckpointTime(); if (!checkpointDao.shouldSave(instant, false, checkpointInterval, clock)) { return Optional.empty(); @@ -85,7 +93,7 @@ public Optional convert(CheckpointMaintainRequest reques .of( new CheckpointWriteRequest( request.getExpirationEpochMs(), - detectorId, + configId, request.getPriority(), // If the document does not already exist, the contents of the upsert element // are inserted as a new document. diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java similarity index 58% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java index 28fdfcc91..479965240 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java @@ -9,17 +9,17 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public class CheckpointMaintainRequest extends QueuedRequest { - private String entityModelId; + private String modelId; - public CheckpointMaintainRequest(long expirationEpochMs, String detectorId, RequestPriority priority, String entityModelId) { - super(expirationEpochMs, detectorId, priority); - this.entityModelId = entityModelId; + public CheckpointMaintainRequest(long expirationEpochMs, String configId, RequestPriority priority, String entityModelId) { + super(expirationEpochMs, configId, priority); + this.modelId = entityModelId; } - public String getEntityModelId() { - return entityModelId; + public String getModelId() { + return modelId; } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java new file mode 100644 index 000000000..ba28043c9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.function.Function; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; + +public abstract class CheckpointMaintainWorker extends ScheduledWorker { + + private Function> converter; + + public CheckpointMaintainWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + String threadPoolName, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Function> converter, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + threadPoolName, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + targetQueue, + stateTtl, + nodeStateManager, + context + ); + this.converter = converter; + } + + @Override + protected List transformRequests(List requests) { + List allRequests = new ArrayList<>(); + for (CheckpointMaintainRequest request : requests) { + Optional converted = converter.apply(request); + if (!converted.isEmpty()) { + allRequests.add(converted.get()); + } + } + return allRequests; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java similarity index 59% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java index d4f1f99af..d6c13cf19 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java @@ -9,10 +9,7 @@ * GitHub history for details. 
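CheckpointMaintainWorker is generic over the analysis type: the conversion from maintain request to write request is injected as a Function, letting AD and forecasting share the scheduling and pruning logic. A wiring sketch (adapter stands for a CheckPointMaintainRequestAdapter as above; names illustrative):

// Each maintain request either becomes a checkpoint write, when the model's
// checkpoint interval has elapsed, or is dropped via Optional.empty().
Function&lt;CheckpointMaintainRequest, Optional&lt;CheckpointWriteRequest&gt;&gt; converter = adapter::convert;
// The worker then applies converter inside transformRequests and forwards
// only the present results to the checkpoint-write queue.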
*/ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -21,7 +18,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.Random; import java.util.Set; @@ -32,19 +28,8 @@ import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetRequest; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndex; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.stats.ADStats; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -53,39 +38,44 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; import org.opensearch.timeseries.util.ExceptionUtil; -import org.opensearch.timeseries.util.ParseUtils; - -/** - * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: - * a). If a checkpoint is not found, we forward that request to the cold start queue. - * b). When a request gets errors, the queue does not change its expiry time and puts - * that request to the end of the queue and automatically retries them before they expire. - * c) When a checkpoint is found, we load that point to memory and score the input - * data point and save the result if a complete model exists. Otherwise, we enqueue - * the sample. If we can host that model in memory (e.g., there is enough memory), - * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the - * updated checkpoint back to disk. 
- * - */ -public class CheckpointReadWorker extends BatchWorker { + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CheckpointReadWorker, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker> + extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class); - public static final String WORKER_NAME = "checkpoint-read"; - private final ModelManager modelManager; - private final CheckpointDao checkpointDao; - private final EntityColdStartWorker entityColdStartQueue; - private final ResultWriteWorker resultWriteQueue; - private final ADIndexManagement indexUtil; - private final CacheProvider cacheProvider; - private final CheckpointWriteWorker checkpointWriteQueue; - private final ADStats adStats; + + protected final ModelManagerType modelManager; + protected final CheckpointType checkpointDao; + protected final ColdStartWorkerType coldStartWorker; + protected final SaveResultStrategyType resultWriteWorker; + protected final IndexManagementType indexUtil; + protected final Stats timeSeriesStats; + protected final CheckpointWriteWorkerType checkpointWriteWorker; + protected final Provider> cacheProvider; + protected final String checkpointIndexName; + protected final StatNames modelCorruptionStat; public CheckpointReadWorker( + String workerName, long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, @@ -93,6 +83,7 @@ public CheckpointReadWorker( Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -100,19 +91,24 @@ public CheckpointReadWorker( float lowSegmentPruneRatio, int maintenanceFreqConstant, Duration executionTtl, - ModelManager modelManager, - CheckpointDao checkpointDao, - EntityColdStartWorker entityColdStartQueue, - ResultWriteWorker resultWriteQueue, + ModelManagerType modelManager, + CheckpointType checkpointDao, + ColdStartWorkerType entityColdStartWorker, NodeStateManager stateManager, - ADIndexManagement indexUtil, - CacheProvider cacheProvider, + IndexManagementType indexUtil, + Provider> cacheProvider, Duration stateTtl, - CheckpointWriteWorker checkpointWriteQueue, - ADStats adStats + CheckpointWriteWorkerType checkpointWriteWorker, + Stats timeSeriesStats, + Setting concurrencySetting, + Setting batchSizeSetting, + String checkpointIndexName, + StatNames modelCorruptionStat, + AnalysisType context, + SaveResultStrategyType resultWriteWorker ) { super( - WORKER_NAME, + workerName, heapSizeInBytes, singleRequestSizeInBytes, maxHeapPercentForQueueSetting, @@ -120,27 +116,31 @@ public CheckpointReadWorker( random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + stateManager, + context ); this.modelManager = modelManager; this.checkpointDao = checkpointDao; - this.entityColdStartQueue = entityColdStartQueue; - 
this.resultWriteQueue = resultWriteQueue; + this.coldStartWorker = entityColdStartWorker; this.indexUtil = indexUtil; this.cacheProvider = cacheProvider; - this.checkpointWriteQueue = checkpointWriteQueue; - this.adStats = adStats; + this.checkpointWriteWorker = checkpointWriteWorker; + this.timeSeriesStats = timeSeriesStats; + this.checkpointIndexName = checkpointIndexName; + this.modelCorruptionStat = modelCorruptionStat; + this.resultWriteWorker = resultWriteWorker; } @Override @@ -149,28 +149,29 @@ protected void executeBatchRequest(MultiGetRequest request, ActionListener toProcess) { + protected MultiGetRequest toBatchRequest(List toProcess) { MultiGetRequest multiGetRequest = new MultiGetRequest(); - for (EntityRequest request : toProcess) { - Optional modelId = request.getModelId(); - if (false == modelId.isPresent()) { + for (FeatureRequest request : toProcess) { + String modelId = request.getModelId(); + if (null == modelId) { continue; } - multiGetRequest.add(new MultiGetRequest.Item(ADCommonName.CHECKPOINT_INDEX_NAME, modelId.get())); + multiGetRequest.add(new MultiGetRequest.Item(checkpointIndexName, modelId)); } return multiGetRequest; } @Override - protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { + protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { return ActionListener.wrap(response -> { + final MultiGetItemResponse[] itemResponses = response.getResponses(); Map successfulRequests = new HashMap<>(); @@ -186,11 +187,11 @@ protected ActionListener getResponseListener(List getResponseListener(List getResponseListener(List modelId = origRequest.getModelId(); - if (modelId.isPresent() && notFoundModels.contains(modelId.get())) { + for (FeatureRequest origRequest : toProcess) { + String modelId = origRequest.getModelId(); + if (modelId != null && notFoundModels.contains(modelId)) { // submit to cold start queue - entityColdStartQueue.put(origRequest); + coldStartWorker.put(origRequest); } } } @@ -241,15 +242,17 @@ protected ActionListener getResponseListener(List modelId = origRequest.getModelId(); - if (modelId.isPresent() && stopDetectorRequests.containsKey(modelId.get())) { - String adID = origRequest.detectorId; + for (FeatureRequest origRequest : toProcess) { + String modelId = origRequest.getModelId(); + if (modelId != null && stopDetectorRequests.containsKey(modelId)) { + String configID = origRequest.getConfigId(); nodeStateManager .setException( - adID, - new EndRunException(adID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId.get()), false) + configID, + new EndRunException(configID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId), false) ); + // once one EndRunException is set, we can break; no point setting the exception repeatedly + break; } } } @@ -258,11 +261,10 @@ protected ActionListener getResponseListener(List { if (ExceptionUtil.isOverloaded(exception)) { - LOG.error("too many get AD model checkpoint requests or shard not available"); + LOG.error("too many get model checkpoint requests or shard not available"); setCoolDownStart(); } else if (ExceptionUtil.isRetryAble(exception)) { // retry all of them @@ -273,9 +275,9 @@ protected ActionListener getResponseListener(List toProcess, + List toProcess, Map successfulRequests, Set retryableRequests ) { @@ -287,42 +289,41 @@ private void processCheckpointIteration( // if false, finally will process next checkpoints boolean processNextInCallBack = false; try { - EntityFeatureRequest origRequest = 
toProcess.get(i); + FeatureRequest origRequest = toProcess.get(i); - Optional modelIdOptional = origRequest.getModelId(); - if (false == modelIdOptional.isPresent()) { + String modelId = origRequest.getModelId(); + if (null == modelId) { return; } - String detectorId = origRequest.getId(); - Entity entity = origRequest.getEntity(); - - String modelId = modelIdOptional.get(); + String configId = origRequest.getConfigId(); + Optional entity = origRequest.getEntity(); MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId); if (checkpointResponse != null) { // successful requests - Optional> checkpoint = checkpointDao - .processGetResponse(checkpointResponse.getResponse(), modelId); + ModelState modelState = checkpointDao + .processHCGetResponse(checkpointResponse.getResponse(), modelId, configId); - if (false == checkpoint.isPresent()) { - // checkpoint is too big + if (null == modelState) { + // checkpoint is not available (e.g., too big or corrupted); cold start again + coldStartWorker.put(origRequest); return; } nodeStateManager .getConfig( - detectorId, - AnalysisType.AD, - onGetDetector( + configId, + context, + processIterationUsingConfig( origRequest, i, - detectorId, + configId, toProcess, successfulRequests, retryableRequests, - checkpoint, + modelState, entity, modelId ) @@ -339,39 +340,47 @@ private void processCheckpointIteration( } } - private ActionListener> onGetDetector( - EntityFeatureRequest origRequest, + protected ActionListener> processIterationUsingConfig( + FeatureRequest origRequest, int index, - String detectorId, - List toProcess, + String configId, + List toProcess, Map successfulRequests, Set retryableRequests, - Optional> checkpoint, - Entity entity, + ModelState restoredModelState, + Optional entity, String modelId ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return ActionListener.wrap(configOptional -> { + if (configOptional.isEmpty()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); return; } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); + Config config = configOptional.get(); - ModelState modelState = modelManager - .processEntityCheckpoint(checkpoint, entity, modelId, detectorId, detector.getShingleSize()); - - ThresholdingResult result = null; + RCFResultType result = null; try { result = modelManager - .getAnomalyResultForEntity(origRequest.getCurrentFeature(), modelState, modelId, entity, detector.getShingleSize()); + .getResult( + new Sample( + origRequest.getCurrentFeature(), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()) + ), + restoredModelState, + modelId, + entity, + config, + origRequest.getTaskId() + ); } catch (IllegalArgumentException e) { // fail to score likely due to model corruption. Re-cold start to recover. 
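// Recovery path, summarized: record the corruption in the stats API, delete
// the persisted checkpoint so the broken state cannot be reloaded, then
// requeue the request on the cold-start worker to rebuild the model from
// historical data.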
LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", origRequest.getModelId()), e); - adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); - if (origRequest.getModelId().isPresent()) { - String entityModelId = origRequest.getModelId().get(); + timeSeriesStats.getStat(modelCorruptionStat.getName()).increment(); + if (null != origRequest.getModelId()) { + String entityModelId = origRequest.getModelId(); checkpointDao .deleteModelCheckpoint( entityModelId, @@ -383,56 +392,26 @@ private ActionListener> onGetDetector( ); } - entityColdStartQueue.put(origRequest); + coldStartWorker.put(origRequest); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); return; } - if (result != null && result.getRcfScore() > 0) { - RequestPriority requestPriority = result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM; - - List resultsToSave = result - .toIndexableResults( - detector, - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + detector.getIntervalInMilliseconds()), - Instant.now(), - Instant.now(), - ParseUtils.getFeatureData(origRequest.getCurrentFeature(), detector), - Optional.ofNullable(entity), - indexUtil.getSchemaVersion(ADIndex.RESULT), - modelId, - null, - null - ); - - for (AnomalyResult r : resultsToSave) { - resultWriteQueue - .put( - new ResultWriteRequest( - origRequest.getExpirationEpochMs(), - detectorId, - requestPriority, - r, - detector.getCustomResultIndex() - ) - ); - } - } + resultWriteWorker.saveResult(result, config, origRequest, modelId); // try to load to cache - boolean loaded = cacheProvider.get().hostIfPossible(detector, modelState); + boolean loaded = cacheProvider.get().hostIfPossible(config, restoredModelState); if (false == loaded) { // not in memory. Maybe cold entities or some other entities // have filled the slot while waiting for loading checkpoints. - checkpointWriteQueue.write(modelState, true, RequestPriority.LOW); + checkpointWriteWorker.write(restoredModelState, true, RequestPriority.LOW); } processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); }, exception -> { LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId, exception)); - nodeStateManager.setException(detectorId, exception); + nodeStateManager.setException(configId, exception); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); }); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java similarity index 94% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java index 9c41e55be..02a374f82 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. 
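The inline result-building block is replaced above by a single resultWriteWorker.saveResult call, hiding AD/forecast differences behind a strategy. Its shape, inferred from the call site (an assumption; the actual interface may differ in detail):

// Assumed contract: turn a model output into indexable results and enqueue
// them for writing, using origRequest for timing and priority context.
public interface SaveResultStrategy&lt;ResultType extends IndexableResult, RCFResultType extends IntermediateResult&lt;ResultType&gt;&gt; {
    void saveResult(RCFResultType result, Config config, FeatureRequest origRequest, String modelId);
}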
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import org.opensearch.action.update.UpdateRequest; diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java similarity index 72% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java index a26cb8b94..0201ff950 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java @@ -1,18 +1,9 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -30,10 +21,6 @@ import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -43,58 +30,69 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.util.ExceptionUtil; -public class CheckpointWriteWorker extends BatchWorker { +public abstract class CheckpointWriteWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(CheckpointWriteWorker.class); - public static final String WORKER_NAME = "checkpoint-write"; - private final CheckpointDao checkpoint; - private final String indexName; - private final Duration checkpointInterval; + protected final CheckpointDaoType checkpoint; + protected final String indexName; + protected final Duration checkpointInterval; public CheckpointWriteWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, + String queueName, + long heapSize, + int singleRequestSize, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, + Setting 
concurrencySetting, Duration executionTtl, - CheckpointDao checkpoint, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager timeSeriesNodeStateManager, + CheckpointDaoType checkpoint, String indexName, Duration checkpointInterval, - NodeStateManager stateManager, - Duration stateTtl + AnalysisType context ) { super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, + queueName, + heapSize, + singleRequestSize, maxHeapPercentForQueueSetting, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + timeSeriesNodeStateManager, + context ); this.checkpoint = checkpoint; this.indexName = indexName; @@ -133,7 +131,7 @@ protected ActionListener getResponseListener(List getResponseListener(List modelState, boolean forceWrite, RequestPriority priority) { + public void write(ModelState modelState, boolean forceWrite, RequestPriority priority) { Instant instant = modelState.getLastCheckpointTime(); if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { return; } if (modelState.getModel() != null) { - String detectorId = modelState.getId(); + String configId = modelState.getConfigId(); String modelId = modelState.getModelId(); - if (modelId == null || detectorId == null) { + if (modelId == null || configId == null) { return; } - nodeStateManager.getConfig(detectorId, AnalysisType.AD, onGetDetector(detectorId, modelId, modelState, priority)); + nodeStateManager.getConfig(configId, context, onGetConfig(configId, modelId, modelState, priority)); } } - private ActionListener> onGetDetector( - String detectorId, + private ActionListener> onGetConfig( + String configId, String modelId, - ModelState modelState, + ModelState modelState, RequestPriority priority ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); return; } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); + Config config = configOptional.get(); try { Map source = checkpoint.toIndexSource(modelState); @@ -192,8 +190,8 @@ private ActionListener> onGetDetector( modelState.setLastCheckpointTime(clock.instant()); CheckpointWriteRequest request = new CheckpointWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + configId, priority, // If the document does not already exist, the contents of the upsert element // are inserted as a new document. 
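The hunk above reduces the checkpoint write path to two decisions: persist only when checkpoint.shouldSave allows it (a forced write, or the last checkpoint is stale relative to checkpointInterval), and stamp the queued request with an expiry one config interval in the future. The following is a minimal, self-contained sketch of that gating logic, using simplified stand-in types and an assumed Instant.MIN "never saved" sentinel rather than the plugin's actual CheckpointDao and ModelState API:

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

// Illustrative stand-in only; the class name and the Instant.MIN sentinel are assumptions.
final class CheckpointGate {
    private final Duration checkpointInterval;
    private final Clock clock;

    CheckpointGate(Duration checkpointInterval, Clock clock) {
        this.checkpointInterval = checkpointInterval;
        this.clock = clock;
    }

    // Mirrors the shouldSave gate: write when forced, never written before, or stale.
    boolean shouldSave(Instant lastCheckpointTime, boolean forceWrite) {
        return forceWrite
            || Instant.MIN.equals(lastCheckpointTime)
            || lastCheckpointTime.plus(checkpointInterval).isBefore(clock.instant());
    }

    // Mirrors the CheckpointWriteRequest expiry above: valid for one config interval from now.
    long expirationEpochMs(long configIntervalMillis) {
        return clock.millis() + configIntervalMillis;
    }
}

Note that the real write() additionally resolves the config through NodeStateManager before building the request, so a missing config simply short-circuits the write instead of queueing a stale checkpoint.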
@@ -214,20 +212,20 @@ private ActionListener> onGetDetector( LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e); } - }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + }, exception -> { LOG.error(new ParameterizedMessage("fail to get config [{}]", configId), exception); }); } - public void writeAll(List> modelStates, String detectorId, boolean forceWrite, RequestPriority priority) { - ActionListener> onGetForAll = ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + public void writeAll(List> modelStates, String configId, boolean forceWrite, RequestPriority priority) { + ActionListener> onGetForAll = ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); return; } - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); + Config config = configOptional.get(); try { List allRequests = new ArrayList<>(); - for (ModelState state : modelStates) { + for (ModelState state : modelStates) { Instant instant = state.getLastCheckpointTime(); if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { continue; @@ -245,8 +243,8 @@ public void writeAll(List> modelStates, String detectorI allRequests .add( new CheckpointWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + configId, priority, // If the document does not already exist, the contents of the upsert element // are inserted as a new document. @@ -266,11 +264,11 @@ public void writeAll(List> modelStates, String detectorI // As we are gonna retry serializing either when the entity is // evicted out of cache or during the next maintenance period, // don't do anything when the exception happens. - LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", detectorId), e); + LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", configId), e); } - }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + }, exception -> { LOG.error(new ParameterizedMessage("fail to get config [{}]", configId), exception); }); - nodeStateManager.getConfig(detectorId, AnalysisType.AD, onGetForAll); + nodeStateManager.getConfig(configId, context, onGetForAll); } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java new file mode 100644 index 000000000..703360a3f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.model.IndexableResult; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ColdEntityWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, RCFResultType extends IntermediateResult, ModelManagerType extends ModelManager, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, CheckpointReadWorkerType extends CheckpointReadWorker> + extends ScheduledWorker { + + public ColdEntityWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + String threadPoolName, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + CheckpointReadWorkerType checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Setting checkpointReadBatchSizeSetting, + Setting expectedColdEntityExecutionMillsSetting, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + threadPoolName, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager, + context + ); + + this.batchSize = checkpointReadBatchSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointReadBatchSizeSetting, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = expectedColdEntityExecutionMillsSetting.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(expectedColdEntityExecutionMillsSetting, it -> this.expectedExecutionTimeInMilliSecsPerRequest = it); + } + + @Override + protected List transformRequests(List requests) { + // guarantee we only send low priority requests + return requests.stream().filter(request -> request.getPriority() == RequestPriority.LOW).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java 
new file mode 100644 index 000000000..088da3157 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java @@ -0,0 +1,204 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class ColdStartWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, IndexableResultType extends IndexableResult, IntermediateResultType extends IntermediateResult, ModelManagerType extends ModelManager, SaveResultStrategyType extends SaveResultStrategy> + extends SingleRequestWorker { + private static final Logger LOG = LogManager.getLogger(ColdStartWorker.class); + + protected final ColdStarterType coldStarter; + protected final CacheType cacheProvider; + private final ModelManagerType modelManager; + private final SaveResultStrategyType resultSaver; + + public ColdStartWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + String threadPoolName, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrency, + Duration executionTtl, + ColdStarterType coldStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + CacheType cacheProvider, + AnalysisType context, + ModelManagerType modelManager, + SaveResultStrategyType resultSaver + ) { + super( + 
workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + threadPoolName, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrency, + executionTtl, + stateTtl, + nodeStateManager, + context + ); + this.coldStarter = coldStarter; + this.cacheProvider = cacheProvider; + this.modelManager = modelManager; + this.resultSaver = resultSaver; + } + + @Override + protected void executeRequest(FeatureRequest coldStartRequest, ActionListener listener) { + String configId = coldStartRequest.getConfigId(); + + String modelId = coldStartRequest.getModelId(); + + if (null == modelId) { + String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); + LOG.warn(error); + listener.onFailure(new RuntimeException(error)); + return; + } + ModelState modelState = createEmptyState(coldStartRequest, modelId, configId); + + ActionListener>> coldStartListener = ActionListener.wrap(r -> { + nodeStateManager.getConfig(configId, context, ActionListener.wrap(configOptional -> { + try { + if (!configOptional.isPresent()) { + LOG + .error( + new ParameterizedMessage( + "fail to load trained model [{}] to cache due to the config not being found.", + modelState.getModelId() + ) + ); + return; + } + Config config = configOptional.get(); + + // score the current feature if training succeeded + if (modelState.getModel().isPresent()) { + String taskId = coldStartRequest.getTaskId(); + if (r != null) { + for (int i = 0; i < r.size(); i++) { + Entry entry = r.get(i); + IndexableResultType trainingResult = createIndexableResult( + config, + taskId, + modelId, + entry, + coldStartRequest.getEntity() + ); + resultSaver.saveResult(trainingResult, config); + } + } + + long dataStartTime = coldStartRequest.getDataStartTimeMillis(); + Sample currentSample = new Sample( + coldStartRequest.getCurrentFeature(), + Instant.ofEpochMilli(dataStartTime), + Instant.ofEpochMilli(dataStartTime + config.getIntervalInMilliseconds()) + ); + IntermediateResultType result = modelManager + .getResult(currentSample, modelState, modelId, coldStartRequest.getEntity(), config, taskId); + resultSaver.saveResult(result, config, coldStartRequest, modelId); + } + + // only load model to memory for real-time analysis that has no task id + if (null == coldStartRequest.getTaskId()) { + cacheProvider.hostIfPossible(configOptional.get(), modelState); + } + + } finally { + listener.onResponse(null); + } + }, listener::onFailure)); + + }, e -> { + try { + if (ExceptionUtil.isOverloaded(e)) { + LOG.error("OpenSearch is overloaded"); + setCoolDownStart(); + } + nodeStateManager.setException(configId, e); + } finally { + listener.onFailure(e); + } + }); + + coldStarter.trainModel(coldStartRequest, configId, modelState, coldStartListener); + } + + protected abstract ModelState createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId); + + protected abstract IndexableResultType createIndexableResult( + Config config, + String taskId, + String modelId, + Entry entry, + Optional entity + ); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java similarity index 92% rename from src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java index
3df70c935..45f1e424b 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -23,8 +23,8 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; /** @@ -53,7 +53,7 @@ public abstract class ConcurrentWorker extend * rate AD's usage on ES threadpools. * @param clusterService Cluster service accessor * @param random Random number generator - * @param adCircuitBreakerService AD Circuit breaker service + * @param circuitBreakerService Circuit breaker service * @param threadPool threadpool accessor * @param settings Cluster settings getter * @param maxQueuedTaskRatio maximum queued tasks ratio in ES threadpools @@ -74,8 +74,9 @@ public ConcurrentWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -85,7 +86,8 @@ public ConcurrentWorker( Setting concurrencySetting, Duration executionTtl, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( queueName, @@ -94,8 +96,9 @@ public ConcurrentWorker( maxHeapPercentForQueueSetting, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -103,7 +106,8 @@ public ConcurrentWorker( lowSegmentPruneRatio, maintenanceFreqConstant, stateTtl, - nodeStateManager + nodeStateManager, + context ); this.permits = new Semaphore(concurrencySetting.get(settings)); @@ -132,7 +136,7 @@ public void maintenance() { */ @Override protected void triggerProcess() { - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { + threadPool.executor(threadPoolName).execute(() -> { if (permits.tryAcquire()) { try { lastExecuteTime = clock.instant(); diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java new file mode 100644 index 000000000..0749381f4 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ratelimit; + +import java.util.Optional; + +import org.opensearch.timeseries.model.Entity; + +public class FeatureRequest extends QueuedRequest { + private final double[] currentFeature; + private final long dataStartTimeMillis; + protected final String modelId; + private final Optional entity; + private final String taskId; + + // used in HC + public FeatureRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + double[] currentFeature, + long dataStartTimeMs, + Entity entity, + String taskId + ) { + super(expirationEpochMs, configId, priority); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + this.modelId = entity.getModelId(configId).isEmpty() ? null : entity.getModelId(configId).get(); + this.entity = Optional.ofNullable(entity); + this.taskId = taskId; + } + + // used in single-stream + public FeatureRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + String modelId, + double[] currentFeature, + long dataStartTimeMs, + String taskId + ) { + super(expirationEpochMs, configId, priority); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + this.modelId = modelId; + this.entity = Optional.empty(); + this.taskId = taskId; + } + + public double[] getCurrentFeature() { + return currentFeature; + } + + public long getDataStartTimeMillis() { + return dataStartTimeMillis; + } + + public String getModelId() { + return modelId; + } + + public Optional getEntity() { + return entity; + } + + public String getTaskId() { + return taskId; + } + + public boolean isRunOnce() { + return taskId != null; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java similarity index 77% rename from src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java index 66c440db9..a13a490de 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java @@ -9,22 +9,22 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public abstract class QueuedRequest { protected long expirationEpochMs; - protected String detectorId; + protected String configId; protected RequestPriority priority; /** * * @param expirationEpochMs Request expiry time in milliseconds - * @param detectorId Detector Id + * @param configId Config Id * @param priority how urgent the request is */ - protected QueuedRequest(long expirationEpochMs, String detectorId, RequestPriority priority) { + protected QueuedRequest(long expirationEpochMs, String configId, RequestPriority priority) { this.expirationEpochMs = expirationEpochMs; - this.detectorId = detectorId; + this.configId = configId; this.priority = priority; } @@ -47,11 +47,11 @@ public void setPriority(RequestPriority priority) { this.priority = priority; } - public String getId() { - return detectorId; + public String getConfigId() { + return configId; } public void setDetectorId(String detectorId) { - this.detectorId = detectorId; + this.configId = detectorId; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java similarity index 89% rename from src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java index 911ae43a5..93df5b1ae 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java @@ -9,9 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; +import static org.opensearch.timeseries.settings.TimeSeriesSettings.COOLDOWN_MINUTES; import java.time.Clock; import java.time.Duration; @@ -39,10 +39,10 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.ExpiringState; import org.opensearch.timeseries.MaintenanceState; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.TimeSeriesException; @@ -156,6 +156,15 @@ public int clearExpiredRequests() { } return removed; } + + public boolean hasConfigId(String configId) { + for (RequestType request : content) { + if (configId.equals(request.getConfigId())) { + return true; + } + } + return false; + } } private static final Logger LOG = LogManager.getLogger(RateLimitedRequestWorker.class); @@ -175,8 +184,9 @@ public int clearExpiredRequests() { protected final ConcurrentSkipListMap requestQueues; private String lastSelectedRequestQueueId; protected Random random; - private CircuitBreakerService adCircuitBreakerService; + private CircuitBreakerService circuitBreakerService; protected ThreadPool threadPool; + protected String threadPoolName; protected Instant cooldownStart; protected int coolDownMinutes; private float maxQueuedTaskRatio; @@ -186,6 +196,7 @@ public int clearExpiredRequests() { protected int maintenanceFreqConstant; private final Duration stateTtl; protected final NodeStateManager nodeStateManager; + protected final
AnalysisType context; public RateLimitedRequestWorker( String workerName, @@ -194,8 +205,9 @@ public RateLimitedRequestWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - CircuitBreakerService adCircuitBreakerService, + CircuitBreakerService circuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -203,7 +215,8 @@ float lowRequestQueuePruneRatio, int maintenanceFreqConstant, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { this.heapSize = heapSizeInBytes; this.singleRequestSize = singleRequestSizeInBytes; @@ -218,8 +231,9 @@ this.workerName = workerName; this.random = random; - this.adCircuitBreakerService = adCircuitBreakerService; + this.circuitBreakerService = circuitBreakerService; this.threadPool = threadPool; + this.threadPoolName = threadPoolName; this.maxQueuedTaskRatio = maxQueuedTaskRatio; this.clock = clock; this.mediumRequestQueuePruneRatio = mediumRequestQueuePruneRatio; @@ -228,22 +242,24 @@ this.lastSelectedRequestQueueId = null; this.requestQueues = new ConcurrentSkipListMap<>(); this.cooldownStart = Instant.MIN; - this.coolDownMinutes = (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()); + this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); this.maintenanceFreqConstant = maintenanceFreqConstant; this.stateTtl = stateTtl; this.nodeStateManager = nodeStateManager; + this.context = context; } - protected String getWorkerName() { + public String getWorkerName() { return workerName; } /** - * To add fairness to multiple detectors, HCAD allocates queues at a per - * detector granularity and pulls off requests across similar queues in a - * round-robin fashion. This way, if one detector has a much higher - * cardinality than other detectors, the unfinished portion of that - * detector’s workload times out, and other detectors’ workloads continue + * To add fairness to multiple analyses, HC allocates queues at a per + * analysis (e.g., detector or forecaster) granularity and pulls off + * requests across similar queues in a round-robin fashion. + * This way, if one analysis has a much higher + * cardinality than other analyses, the unfinished portion of that + * analysis's workload times out, and other analyses’ workloads continue * operating with predictable performance. For example, for loading checkpoints, * HCAD pulls off 10 requests from one detector’ queues, issues a mget request * to ES, wait for it to finish, and then does it again for other detectors’ @@ -305,7 +321,7 @@ protected void putOnly(RequestType request) { // just use the RequestQueue priority (i.e., low or high) as the key of the RequestQueue map. RequestQueue requestQueue = requestQueues .computeIfAbsent( RequestPriority.MEDIUM == request.getPriority() ?
request.getConfigId() : request.getPriority().name(), k -> new RequestQueue() ); @@ -429,7 +445,7 @@ private void maintainForMemory() { int exceededSize = exceededSize(); if (exceededSize > 0) { prune(requestQueues, exceededSize); - } else if (adCircuitBreakerService.isOpen()) { + } else if (circuitBreakerService.isOpen()) { // remove a few items in each RequestQueue prune(requestQueues); } @@ -551,7 +567,7 @@ protected void process() { } catch (Exception e) { LOG.error(new ParameterizedMessage("Fail to process requests in [{}].", this.workerName), e); } - }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), threadPoolName); } else { try { triggerProcess(); @@ -566,6 +582,29 @@ protected void process() { } } + /** + * + * @param configId Config Id + * @return whether there is any unfinished request belonging to a configId + */ + public boolean hasConfigId(String configId) { + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + String requestId = requestQueueEntry.getKey(); + if (requestId.equals(RequestPriority.LOW.name()) || requestId.equals(RequestPriority.HIGH.name())) { + RequestQueue requests = requestQueueEntry.getValue(); + if (requests.hasConfigId(configId)) { + return true; + } + } else { + // requestId is config Id + if (requestId.equals(configId)) { + return true; + } + } + } + return false; + } + /** * How to execute requests is abstracted out and left to RateLimitedQueue's subclasses to implement. */ diff --git a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java b/src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java similarity index 88% rename from src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java rename to src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java index 3193d2285..29fb14523 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public enum RequestPriority { LOW, diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java new file mode 100644 index 000000000..9eba5d1b5 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + + package org.opensearch.timeseries.ratelimit; + + import java.io.IOException; + + import org.opensearch.ad.model.AnomalyResult; + import org.opensearch.ad.ratelimit.ADResultWriteRequest; + import org.opensearch.core.common.io.stream.StreamInput; + import org.opensearch.core.common.io.stream.StreamOutput; + import org.opensearch.core.common.io.stream.Writeable; + import org.opensearch.forecast.model.ForecastResult; + import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; + import org.opensearch.timeseries.model.IndexableResult; + + public abstract class ResultWriteRequest<ResultType extends IndexableResult> extends QueuedRequest implements Writeable { + private final ResultType result; + // If resultIndex is null, the result will be stored in the default result index. + private final String resultIndex; + + public ResultWriteRequest(long expirationEpochMs, String detectorId, RequestPriority priority, ResultType result, String resultIndex) { + super(expirationEpochMs, detectorId, priority); + this.result = result; + this.resultIndex = resultIndex; + } + + /** + * + * @param <T> subclass type + * @param <R> result type + * @param expirationEpochMs expiration epoch in milliseconds + * @param configId config id + * @param priority request priority + * @param result result + * @param resultIndex result index + * @param clazz The clazz parameter is used to pass the class object of the desired subtype, which allows us to perform a dynamic cast to T and return the correctly-typed instance. + * @return the newly created write request, cast to the requested subtype T + */ + public static <T extends ResultWriteRequest<R>, R extends IndexableResult> T create( + long expirationEpochMs, + String configId, + RequestPriority priority, + IndexableResult result, + String resultIndex, + Class<T> clazz + ) { + if (result instanceof AnomalyResult) { + return clazz.cast(new ADResultWriteRequest(expirationEpochMs, configId, priority, (AnomalyResult) result, resultIndex)); + } else if (result instanceof ForecastResult) { + return clazz.cast(new ForecastResultWriteRequest(expirationEpochMs, configId, priority, (ForecastResult) result, resultIndex)); + } else { + throw new IllegalArgumentException("Unsupported result type"); + } + } + + public ResultWriteRequest(StreamInput in, Writeable.Reader<ResultType> resultReader) throws IOException { + this.result = resultReader.read(in); + this.resultIndex = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + result.writeTo(out); + out.writeOptionalString(resultIndex); + } + + public ResultType getResult() { + return result; + } + + public String getResultIndex() { + return resultIndex; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java similarity index 59% rename from src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java index 02152b086..faaf7852e 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java @@ -1,19 +1,11 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details.
*/ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; +import java.io.IOException; import java.time.Clock; import java.time.Duration; import java.util.List; @@ -25,12 +17,8 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.DocWriteRequest; import org.opensearch.action.index.IndexRequest; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedFunction; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -44,63 +32,78 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.transport.ResultBulkRequest; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; import org.opensearch.timeseries.util.ExceptionUtil; -public class ResultWriteWorker extends BatchWorker { +public abstract class ResultWriteWorker, BatchRequestType extends ResultBulkRequest, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, ResultHandlerType extends IndexMemoryPressureAwareResultHandler> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(ResultWriteWorker.class); - public static final String WORKER_NAME = "result-write"; - - private final MultiEntityResultHandler resultHandler; - private NamedXContentRegistry xContentRegistry; + protected final ResultHandlerType resultHandler; + protected NamedXContentRegistry xContentRegistry; + private CheckedFunction resultParser; public ResultWriteWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, + String queueName, + long heapSize, + int singleRequestSize, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, + Setting concurrencySetting, Duration executionTtl, - MultiEntityResultHandler resultHandler, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager timeSeriesNodeStateManager, + ResultHandlerType resultHandler, NamedXContentRegistry xContentRegistry, - NodeStateManager stateManager, - Duration stateTtl + CheckedFunction resultParser, + AnalysisType context ) { super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, + queueName, + heapSize, + singleRequestSize, maxHeapPercentForQueueSetting, clusterService, random, adCircuitBreakerService, threadPool, + 
threadPoolName, settings, maxQueuedTaskRatio, clock, mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_RESULT_WRITE_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + timeSeriesNodeStateManager, + context ); this.resultHandler = resultHandler; this.xContentRegistry = xContentRegistry; + this.resultParser = resultParser; } @Override - protected void executeBatchRequest(ADResultBulkRequest request, ActionListener listener) { + protected void executeBatchRequest(BatchRequestType request, ActionListener listener) { if (request.numberOfActions() < 1) { listener.onResponse(null); return; @@ -109,19 +112,7 @@ protected void executeBatchRequest(ADResultBulkRequest request, ActionListener toProcess) { - final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); - for (ResultWriteRequest request : toProcess) { - bulkRequest.add(request); - } - return bulkRequest; - } - - @Override - protected ActionListener getResponseListener( - List toProcess, - ADResultBulkRequest bulkRequest - ) { + protected ActionListener getResponseListener(List toProcess, BatchRequestType bulkRequest) { return ActionListener.wrap(adResultBulkResponse -> { if (adResultBulkResponse == null || false == adResultBulkResponse.getRetryRequests().isPresent()) { // all successful @@ -134,12 +125,12 @@ protected ActionListener getResponseListener( // retry all of them super.putAll(toProcess); } else if (ExceptionUtil.isOverloaded(exception)) { - LOG.error("too many get AD model checkpoint requests or shard not avialble"); + LOG.error("too many result write requests or shard not available"); setCoolDownStart(); } - for (ResultWriteRequest request : toProcess) { - nodeStateManager.setException(request.getId(), exception); + for (ResultWriteRequestType request : toProcess) { + nodeStateManager.setException(request.getConfigId(), exception); } LOG.error("Fail to save results", exception); }); @@ -150,50 +141,18 @@ private void enqueueRetryRequestIteration(List requestToRetry, int return; } DocWriteRequest currentRequest = requestToRetry.get(index); - Optional resultToRetry = getAnomalyResult(currentRequest); + Optional resultToRetry = getResult(currentRequest); if (false == resultToRetry.isPresent()) { enqueueRetryRequestIteration(requestToRetry, index + 1); return; } - AnomalyResult result = resultToRetry.get(); - String detectorId = result.getConfigId(); - nodeStateManager.getConfig(detectorId, AnalysisType.AD, onGetDetector(requestToRetry, index, detectorId, result)); - } - private ActionListener> onGetDetector( - List requestToRetry, - int index, - String detectorId, - AnomalyResult resultToRetry - ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); - enqueueRetryRequestIteration(requestToRetry, index + 1); - return; - } - - AnomalyDetector detector = (AnomalyDetector) detectorOptional.get(); - super.put( - new ResultWriteRequest( - // expire based on execute start time - resultToRetry.getExecutionStartTime().toEpochMilli() + detector.getIntervalInMilliseconds(), - detectorId, - resultToRetry.isHighPriority() ?
RequestPriority.HIGH : RequestPriority.MEDIUM, - resultToRetry, - detector.getCustomResultIndex() - ) - ); - - enqueueRetryRequestIteration(requestToRetry, index + 1); - - }, exception -> { - LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); - enqueueRetryRequestIteration(requestToRetry, index + 1); - }); + ResultType result = resultToRetry.get(); + String id = result.getConfigId(); + nodeStateManager.getConfig(id, context, onGetConfig(requestToRetry, index, id, result)); } - private Optional getAnomalyResult(DocWriteRequest request) { + protected Optional getResult(DocWriteRequest request) { try { if (false == (request instanceof IndexRequest)) { LOG.error(new ParameterizedMessage("We should only send IndexRquest, but get [{}].", request)); @@ -211,11 +170,52 @@ private Optional getAnomalyResult(DocWriteRequest request) { // org.opensearch.core.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found // [null] xContentParser.nextToken(); - return Optional.of(AnomalyResult.parse(xContentParser)); + return Optional.of(resultParser.apply(xContentParser)); } } catch (Exception e) { LOG.error(new ParameterizedMessage("Fail to parse index request [{}]", request), e); } return Optional.empty(); } + + private ActionListener> onGetConfig( + List requestToRetry, + int index, + String id, + ResultType resultToRetry + ) { + return ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", id)); + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + + Config config = configOptional.get(); + super.put( + createResultWriteRequest( + // expire based on execute start time + resultToRetry.getExecutionStartTime().toEpochMilli() + config.getIntervalInMilliseconds(), + id, + resultToRetry.isHighPriority() ? 
RequestPriority.HIGH : RequestPriority.MEDIUM, + resultToRetry, + config.getCustomResultIndex() + ) + ); + + enqueueRetryRequestIteration(requestToRetry, index + 1); + + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get config [{}]", id), exception); + enqueueRetryRequestIteration(requestToRetry, index + 1); + }); + } + + protected abstract ResultWriteRequestType createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ResultType result, + String resultIndex + ); } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/SaveResultStrategy.java b/src/main/java/org/opensearch/timeseries/ratelimit/SaveResultStrategy.java new file mode 100644 index 000000000..d5c907c16 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/SaveResultStrategy.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Instant; +import java.util.Optional; + +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; + +public interface SaveResultStrategy> { + void saveResult(RCFResultType result, Config config, FeatureRequest origRequest, String modelId); + + void saveResult( + RCFResultType result, + Config config, + Instant dataStart, + Instant dataEnd, + String modelId, + double[] currentData, + Optional entity, + String taskId + ); + + void saveResult(IndexableResultType result, Config config); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java similarity index 90% rename from src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java index 115d79882..04dfdd900 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -18,18 +18,19 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; public abstract class ScheduledWorker extends RateLimitedRequestWorker { - private static final Logger LOG = LogManager.getLogger(ColdEntityWorker.class); + private static final Logger LOG = LogManager.getLogger(ADColdEntityWorker.class); // the number of requests forwarded to the target queue protected volatile int batchSize; @@ -47,6 +48,7 @@ public ScheduledWorker( Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -55,7 +57,8 @@ public ScheduledWorker( int maintenanceFreqConstant, RateLimitedRequestWorker targetQueue, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( workerName, @@ -66,6 +69,7 @@ public ScheduledWorker( random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -73,7 +77,8 @@ public ScheduledWorker( lowSegmentPruneRatio, maintenanceFreqConstant, stateTtl, - nodeStateManager + nodeStateManager, + context ); this.targetQueue = targetQueue; @@ -114,7 +119,7 @@ private void pullRequests() { private synchronized void schedulePulling(TimeValue delay) { try { - threadPool.schedule(this::pullRequests, delay, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + threadPool.schedule(this::pullRequests, delay, threadPoolName); } catch (Exception e) { LOG.error("Fail to schedule cold entity pulling", e); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java similarity index 92% rename from src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java index e789e36fa..9b11db99c 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -24,6 +24,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.breaker.CircuitBreakerService; @@ -39,6 +40,7 @@ public SingleRequestWorker( Random random, CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, + String threadPoolName, Settings settings, float maxQueuedTaskRatio, Clock clock, @@ -48,7 +50,8 @@ public SingleRequestWorker( Setting concurrencySetting, Duration executionTtl, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( queueName, @@ -59,6 +62,7 @@ public SingleRequestWorker( random, adCircuitBreakerService, threadPool, + threadPoolName, settings, maxQueuedTaskRatio, clock, @@ -68,7 +72,8 @@ public SingleRequestWorker( concurrencySetting, executionTtl, stateTtl, - nodeStateManager + nodeStateManager, + context ); } diff --git a/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java b/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java new file mode 100644 index 000000000..f31e6ce0c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.model.DateRange; + +import com.google.common.collect.ImmutableList; + +public abstract class RestJobAction extends BaseRestHandler { + protected DateRange parseInputDateRange(RestRequest request) throws IOException { + if (!request.hasContent()) { + return null; + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + DateRange dateRange = DateRange.parse(parser); + return dateRange; + } + + @Override + public List routes() { + return ImmutableList.of(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/RestStatsAction.java b/src/main/java/org/opensearch/timeseries/rest/RestStatsAction.java new file mode 100644 index 000000000..bb1585566 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestStatsAction.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.TreeSet; + +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.Strings; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public abstract class RestStatsAction extends BaseRestHandler { + private Stats timeSeriesStats; + private DiscoveryNodeFilterer nodeFilter; + + /** + * Constructor + * + * 
@param timeSeriesStats TimeSeriesStats object + * @param nodeFilter util class to get eligible data nodes + */ + public RestStatsAction(Stats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + this.timeSeriesStats = timeSeriesStats; + this.nodeFilter = nodeFilter; + } + + /** + * Creates a StatsRequest from a RestRequest + * + * @param request RestRequest + * @return StatsRequest Request containing stats to be retrieved + */ + protected StatsRequest getRequest(RestRequest request) { + // parse the nodes the user wants to query the stats for + String nodesIdsStr = request.param("nodeId"); + Set validStats = timeSeriesStats.getStats().keySet(); + + StatsRequest statsRequest = null; + if (!Strings.isEmpty(nodesIdsStr)) { + String[] nodeIdsArr = nodesIdsStr.split(","); + statsRequest = new StatsRequest(nodeIdsArr); + } else { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + statsRequest = new StatsRequest(dataNodes); + } + + statsRequest.timeout(request.param("timeout")); + + // parse the stats the user wants to see + HashSet statsSet = null; + String statsStr = request.param("stat"); + if (!Strings.isEmpty(statsStr)) { + statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); + } + + if (statsSet == null) { + statsRequest.addAll(validStats); // retrieve all stats if none are specified + } else if (statsSet.size() == 1 && statsSet.contains(StatsRequest.ALL_STATS_KEY)) { + statsRequest.addAll(validStats); + } else if (statsSet.contains(StatsRequest.ALL_STATS_KEY)) { + throw new IllegalArgumentException( + "Request " + request.path() + " contains " + StatsRequest.ALL_STATS_KEY + " and individual stats" + ); + } else { + Set invalidStats = new TreeSet<>(); + for (String stat : statsSet) { + if (validStats.contains(stat)) { + statsRequest.addStat(stat); + } else { + invalidStats.add(stat); + } + } + + if (!invalidStats.isEmpty()) { + throw new IllegalArgumentException(unrecognized(request, invalidStats, statsRequest.getStatsToBeRetrieved(), "stat")); + } + } + return statsRequest; + } + +} diff --git a/src/main/java/org/opensearch/timeseries/rest/RestValidateAction.java b/src/main/java/org/opensearch/timeseries/rest/RestValidateAction.java new file mode 100644 index 000000000..fa546c3d9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestValidateAction.java @@ -0,0 +1,117 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + + package org.opensearch.timeseries.rest; + + import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + import java.io.IOException; + import java.util.Arrays; + import java.util.Collections; + import java.util.HashSet; + import java.util.Set; + + import org.apache.commons.lang3.StringUtils; + import org.opensearch.ad.model.AnomalyDetector; + import org.opensearch.client.node.NodeClient; + import org.opensearch.common.unit.TimeValue; + import org.opensearch.core.rest.RestStatus; + import org.opensearch.core.xcontent.XContentParser; + import org.opensearch.forecast.model.Forecaster; + import org.opensearch.rest.BytesRestResponse; + import org.opensearch.rest.RestChannel; + import org.opensearch.rest.RestRequest; + import org.opensearch.timeseries.AnalysisType; + import org.opensearch.timeseries.constant.CommonMessages; + import org.opensearch.timeseries.model.Config; + import org.opensearch.timeseries.model.ConfigValidationIssue; + import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; + import org.opensearch.timeseries.transport.ValidateConfigRequest; + import org.opensearch.timeseries.transport.ValidateConfigResponse; + + /** + * This class consists of the REST handler logic to validate anomaly detector and forecaster configurations. + */ + public class RestValidateAction { + private AnalysisType context; + private Integer maxSingleStreamConfigs; + private Integer maxHCConfigs; + private Integer maxFeatures; + private Integer maxCategoricalFields; + private TimeValue requestTimeout; + + public RestValidateAction( + AnalysisType context, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + Integer maxCategoricalFields, + TimeValue requestTimeout + ) { + this.context = context; + this.maxSingleStreamConfigs = maxSingleStreamConfigs; + this.maxHCConfigs = maxHCConfigs; + this.maxFeatures = maxFeatures; + this.maxCategoricalFields = maxCategoricalFields; + this.requestTimeout = requestTimeout; + } + + public void sendValidationParseResponse(ConfigValidationIssue issue, RestChannel channel) throws IOException { + try { + BytesRestResponse restResponse = new BytesRestResponse( + RestStatus.OK, + new ValidateConfigResponse(issue).toXContent(channel.newBuilder()) + ); + channel.sendResponse(restResponse); + } catch (Exception e) { + channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } + } + + private Boolean validationTypesAreAccepted(String validationType) { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return (!Collections.disjoint(typesInRequest, AbstractTimeSeriesActionHandler.ALL_VALIDATION_ASPECTS_STRS)); + } + + public ValidateConfigRequest prepareRequest(RestRequest request, NodeClient client, String typesStr) throws IOException { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + // if the type param isn't blank and isn't one of the supported validation types, throw an exception + if (!StringUtils.isBlank(typesStr)) { + if (!validationTypesAreAccepted(typesStr)) { + throw new IllegalStateException(CommonMessages.NOT_EXISTENT_VALIDATION_TYPE); + } + } + + Config config = null; + + if (context.isAD()) { + config = AnomalyDetector.parse(parser); + } else if (context.isForecast()) { + config = Forecaster.parse(parser); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + ValidateConfigRequest validateAnomalyDetectorRequest = new
ValidateConfigRequest( + context, + config, + typesStr, + maxSingleStreamConfigs, + maxHCConfigs, + maxFeatures, + requestTimeout, + maxCategoricalFields + ); + return validateAnomalyDetectorRequest; + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java new file mode 100644 index 000000000..174a2bdbb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java @@ -0,0 +1,862 @@ +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG; +import static org.opensearch.timeseries.util.ParseUtils.parseAggregators; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.isExceptionCausedByInvalidQuery; + +import java.io.IOException; +import java.time.Clock; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import
org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class AbstractTimeSeriesActionHandler<T extends ActionResponse, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager<TaskCacheManagerType, TaskTypeEnum, TaskClass>> + implements + Processor<T> { + + protected final Logger logger = LogManager.getLogger(AbstractTimeSeriesActionHandler.class); + + public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; + public static final Integer MAX_NAME_SIZE = 64; + public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; + + public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + + AbstractTimeSeriesActionHandler.MAX_NAME_SIZE + + " characters."; + + public static final Set<String> ALL_VALIDATION_ASPECTS_STRS = Arrays + .asList(ValidationAspect.values()) + .stream() + .map(aspect -> aspect.getName()) + .collect(Collectors.toSet()); + + protected final Config config; + protected final IndexManagement<IndexType> timeSeriesIndices; + protected final boolean isDryRun; + protected final Client client; + protected final String id; + protected final SecurityClientUtil clientUtil; + protected final User user; + protected final RestRequest.Method method; + protected final ConfigUpdateConfirmer<IndexType, IndexManagementType, TaskCacheManagerType, TaskTypeEnum, TaskClass, TaskManagerType> handler; + protected final ClusterService clusterService; + protected final NamedXContentRegistry xContentRegistry; + protected final TimeValue requestTimeout; + protected final WriteRequest.RefreshPolicy refreshPolicy; + protected final Long seqNo; + protected final Long primaryTerm; + protected final String validationType; + protected final SearchFeatureDao searchFeatureDao; + protected final Integer maxFeatures; + protected final Integer maxCategoricalFields; + protected final AnalysisType context; + protected final List<TaskTypeEnum> batchTasks; + protected final boolean canUpdateEverything; + + protected final Integer maxSingleStreamConfigs; + protected final Integer maxHCConfigs; + protected final Clock clock; + protected final Settings settings; + + public AbstractTimeSeriesActionHandler( + Config config, + IndexManagement<IndexType> timeSeriesIndices, + boolean isDryRun, + Client client, + String id, + SecurityClientUtil clientUtil, + User user, + RestRequest.Method method, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + TransportService transportService, + TimeValue requestTimeout, + WriteRequest.RefreshPolicy refreshPolicy, + Long seqNo, + Long primaryTerm, + String
validationType, + SearchFeatureDao searchFeatureDao, + Integer maxFeatures, + Integer maxCategoricalFields, + AnalysisType context, + TaskManagerType taskManager, + List<TaskTypeEnum> batchTasks, + boolean canUpdateCategoryField, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Clock clock, + Settings settings + ) { + this.config = config; + this.timeSeriesIndices = timeSeriesIndices; + this.isDryRun = isDryRun; + this.client = client; + this.id = id == null ? "" : id; + this.clientUtil = clientUtil; + this.user = user; + this.method = method; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.requestTimeout = requestTimeout; + this.refreshPolicy = refreshPolicy; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.validationType = validationType; + this.searchFeatureDao = searchFeatureDao; + this.maxFeatures = maxFeatures; + this.maxCategoricalFields = maxCategoricalFields; + this.context = context; + this.batchTasks = batchTasks; + this.canUpdateEverything = canUpdateCategoryField; + this.maxSingleStreamConfigs = maxSingleStreamConfigs; + this.maxHCConfigs = maxHCConfigs; + this.clock = clock; + this.settings = settings; + this.handler = new ConfigUpdateConfirmer<>(taskManager, transportService); + } + + /** + * Entry point for processing a create/update/validate config request. + * + * If the validation type is detector or forecaster, all validation in this class consists of + * blocking checks against the configuration: any issue raised here blocks the user from + * creating the config (e.g., an anomaly detector). + * If the validation aspect is model, further non-blocking validation runs after the blocking + * validation has passed. Issues raised during model validation are merely warnings telling the + * user how the configuration could be changed to raise the likelihood of model training + * completing successfully. + * + * For custom result index validation: if the config does not use a custom result index, check + * whether the config index exists and create it if it does not. Otherwise, check whether the + * custom result index exists; if it does, verify that its mapping matches the expected result + * index mapping and that the user has permission to write to it. If it does not exist, create + * the custom result index with the result index mapping.
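+ *
+ * <p>Illustrative call sequence (a sketch, not part of this change; the REST/transport wiring
+ * and the concrete subclass are assumed):
+ * <pre>{@code
+ * handler.start(ActionListener.wrap(
+ *     response -> listener.onResponse(response), // indexed, or dry run passed all blocking checks
+ *     listener::onFailure));                     // blocking validation issue or indexing failure
+ * }</pre>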
+ */ + @Override + public void start(ActionListener listener) { + String resultIndex = config.getCustomResultIndex(); + // use default detector result index which is system index + if (resultIndex == null) { + createOrUpdateConfig(listener); + return; + } + + if (this.isDryRun) { + if (timeSeriesIndices.doesIndexExist(resultIndex)) { + timeSeriesIndices + .validateResultIndexAndExecute( + resultIndex, + () -> createOrUpdateConfig(listener), + false, + ActionListener.wrap(r -> createOrUpdateConfig(listener), ex -> { + logger.error(ex); + listener.onFailure(createValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX)); + return; + }) + ); + return; + } else { + createOrUpdateConfig(listener); + return; + } + } + // use custom result index if not validating and resultIndex not null + timeSeriesIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateConfig(listener), listener); + } + + // if isDryRun is true then this method is being executed through Validation API meaning actual + // index won't be created, only validation checks will be executed throughout the class + private void createOrUpdateConfig(ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (!timeSeriesIndices.doesConfigIndexExist() && !this.isDryRun) { + logger.info("Config Indices do not exist"); + timeSeriesIndices + .initConfigIndex( + ActionListener + .wrap( + response -> onCreateMappingsResponse(response, false, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + logger.info("DryRun variable " + this.isDryRun); + validateName(this.isDryRun, listener); + } + } catch (Exception e) { + logger.error("Failed to create or update forecaster " + id, e); + listener.onFailure(e); + } + } + + protected void validateName(boolean indexingDryRun, ActionListener listener) { + if (!config.getName().matches(NAME_REGEX)) { + listener.onFailure(createValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME)); + return; + + } + if (config.getName().length() > MAX_NAME_SIZE) { + listener.onFailure(createValidationException(AbstractTimeSeriesActionHandler.INVALID_NAME_SIZE, ValidationIssueType.NAME)); + return; + } + validateTimeField(indexingDryRun, listener); + } + + protected void validateTimeField(boolean indexingDryRun, ActionListener listener) { + String givenTimeField = config.getTimeField(); + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(config.getIndices().toArray(new String[0])).fields(givenTimeField); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + // comments explaining fieldMappingResponse parsing can be found inside validateCategoricalField(String, boolean) + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + boolean foundField = false; + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) 
metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.DATE_TYPE)) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.FORECASTER + ) + ); + return; + } + } + } + } + } + } + } + if (!foundField) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.FORECASTER + ) + ); + return; + } + prepareConfigIndexing(indexingDryRun, listener); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + clientUtil + .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); + } + + /** + * Prepare for indexing a new config. + * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false + */ + protected void prepareConfigIndexing(boolean indexingDryRun, ActionListener listener) { + if (method == RestRequest.Method.PUT) { + handler + .confirmJobRunning( + clusterService, + client, + id, + listener, + () -> updateConfig(id, indexingDryRun, listener), + xContentRegistry + ); + } else { + createConfig(indexingDryRun, listener); + } + } + + protected void updateConfig(String id, boolean indexingDryRun, ActionListener listener) { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, id); + client + .get( + request, + ActionListener + .wrap( + response -> onGetConfigResponse(response, indexingDryRun, id, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetConfigResponse(GetResponse response, boolean indexingDryRun, String id, ActionListener listener) { + if (!response.isExists()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + id, RestStatus.NOT_FOUND)); + return; + } + try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Config existingConfig = parse(parser, response); + // If category field changed, frontend may not be able to render AD result for different config types correctly. + // For example, if an anomaly detector changed from HC to single entity detector, AD result page may show multiple anomaly + // result points on the same time point if there are multiple entities have anomaly results. + // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type + // in top N entities list. That's confusing. + // So we decide to block updating detector category field. + // for forecasting, we will not show results after forecaster configuration change (excluding changes like description) + // thus it is safe to allow updating everything. In the future, we might change AD to allow such behavior. 
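+ // Illustrative effect of the guard below (sketch, not part of this change): for an AD config,
+ // an update that changes the category fields, e.g. ["ip"] -> ["error_type"], or that points to a
+ // different custom result index is rejected with 400 BAD_REQUEST, while a forecaster
+ // (canUpdateEverything == true) may change both.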
+ if (!canUpdateEverything) { + if (!ParseUtils.listEqualsWithoutConsideringOrder(existingConfig.getCategoryFields(), config.getCategoryFields())) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST)); + return; + } + if (!Objects.equals(existingConfig.getCustomResultIndex(), config.getCustomResultIndex())) { + listener + .onFailure( + new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST) + ); + return; + } + } + + ActionListener confirmBatchRunningListener = ActionListener + .wrap( + r -> searchConfigInputIndices(id, indexingDryRun, listener), + // can't update config if there is task running + listener::onFailure + ); + + handler.confirmBatchRunning(id, batchTasks, confirmBatchRunningListener); + } catch (IOException e) { + String message = "Failed to parse anomaly detector " + id; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + + } + + protected void validateAgainstExistingHCConfig(String detectorId, boolean indexingDryRun, ActionListener listener) { + if (timeSeriesIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(Config.CATEGORY_FIELD)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchHCConfigResponse(response, detectorId, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + validateCategoricalField(detectorId, indexingDryRun, listener); + } + + } + + protected void createConfig(boolean indexingDryRun, ActionListener listener) { + try { + List categoricalFields = config.getCategoryFields(); + if (categoricalFields != null && categoricalFields.size() > 0) { + validateAgainstExistingHCConfig(null, indexingDryRun, listener); + } else { + if (timeSeriesIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchSingleStreamConfigResponse(response, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + searchConfigInputIndices(null, indexingDryRun, listener); + } + + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void onSearchSingleStreamConfigResponse(SearchResponse response, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value >= getMaxSingleStreamConfigs()) { + String errorMsgSingleEntity = getExceedMaxSingleStreamConfigsErrorMsg(getMaxSingleStreamConfigs()); + logger.error(errorMsgSingleEntity); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity)); + } else { + searchConfigInputIndices(null, indexingDryRun, listener); + } + } + + protected void onSearchHCConfigResponse(SearchResponse response, String 
detectorId, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value >= getMaxHCConfigs()) { + String errorMsg = getExceedMaxHCConfigsErrorMsg(getMaxHCConfigs()); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateCategoricalField(detectorId, indexingDryRun, listener); + } + } + + @SuppressWarnings("unchecked") + protected void validateCategoricalField(String detectorId, boolean indexingDryRun, ActionListener listener) { + List categoryField = config.getCategoryFields(); + + if (categoryField == null) { + searchConfigInputIndices(detectorId, indexingDryRun, listener); + return; + } + + // we only support a certain number of categorical field + // If there is more fields than required, Config's constructor + // throws validation exception before reaching this line + int maxCategoryFields = maxCategoricalFields; + if (categoryField.size() > maxCategoryFields) { + listener + .onFailure( + createValidationException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), ValidationIssueType.CATEGORY) + ); + return; + } + + String categoryField0 = categoryField.get(0); + + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(config.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + // example getMappingsResponse: + // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', + // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} + // for nested field, it would be + // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} + boolean foundField = false; + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + // example output: + // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type != null && type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { + listener + .onFailure( + createValidationException(CATEGORICAL_FIELD_TYPE_ERR_MSG, ValidationIssueType.CATEGORY) + ); + return; + } + } + } + } + + } + } + } + + if (foundField == false) { + listener + .onFailure( + createValidationException( + 
String.format(Locale.ROOT, CATEGORY_NOT_FOUND_ERR_MSG, categoryField0), + ValidationIssueType.CATEGORY + ) + ); + return; + } + + searchConfigInputIndices(detectorId, indexingDryRun, listener); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + + clientUtil + .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); + } + + protected void searchConfigInputIndices(String detectorId, boolean indexingDryRun, ActionListener listener) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(QueryBuilders.matchAllQuery()) + .size(0) + .timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + + ActionListener searchResponseListener = ActionListener + .wrap( + searchResponse -> onSearchConfigInputIndicesResponse(searchResponse, detectorId, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ); + + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, context, searchResponseListener); + } + + protected void onSearchConfigInputIndicesResponse( + SearchResponse response, + String detectorId, + boolean indexingDryRun, + ActionListener listener + ) throws IOException { + if (response.getHits().getTotalHits().value == 0) { + String errorMsg = getNoDocsInUserIndexErrorMsg(Arrays.toString(config.getIndices().toArray(new String[0]))); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.INDICES)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateConfigFeatures(detectorId, indexingDryRun, listener); + } + } + + protected void checkConfigNameExists(String configId, boolean indexingDryRun, ActionListener listener) throws IOException { + if (timeSeriesIndices.doesConfigIndexExist()) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + // src/main/resources/mappings/config.json#L14 + boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", config.getName())); + if (StringUtils.isNotBlank(configId)) { + boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, configId)); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + searchResponse -> onSearchConfigNameResponse(searchResponse, config.getName(), indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + tryIndexingConfig(indexingDryRun, listener); + } + + } + + protected void onSearchConfigNameResponse(SearchResponse response, String name, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value > 0) { + String errorMsg = getDuplicateConfigErrorMsg(name); + logger.warn(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.NAME)); + } else { + listener.onFailure(new OpenSearchStatusException(errorMsg, RestStatus.CONFLICT)); + } + } else { + tryIndexingConfig(indexingDryRun, listener); + } + } + + protected void 
tryIndexingConfig(boolean indexingDryRun, ActionListener listener) throws IOException { + if (!indexingDryRun) { + indexConfig(id, listener); + } else { + finishConfigValidationOrContinueToModelValidation(listener); + } + } + + protected Set getValidationTypes(String validationType) { + if (StringUtils.isBlank(validationType)) { + return getDefaultValidationType(); + } else { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return ValidationAspect + .getNames(Sets.intersection(AbstractTimeSeriesActionHandler.ALL_VALIDATION_ASPECTS_STRS, typesInRequest)); + } + } + + protected void finishConfigValidationOrContinueToModelValidation(ActionListener listener) { + logger.info("Skipping indexing detector. No blocking issue found so far."); + if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) { + listener.onResponse(null); + } else { + validateModel(listener); + } + } + + @SuppressWarnings("unchecked") + protected void indexConfig(String id, ActionListener listener) throws IOException { + Config copiedConfig = copyConfig(user, config); + IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) + .setRefreshPolicy(refreshPolicy) + .source(copiedConfig.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .timeout(requestTimeout); + if (StringUtils.isNotBlank(id)) { + indexRequest.id(id); + } + + client.index(indexRequest, new ActionListener() { + @Override + public void onResponse(IndexResponse indexResponse) { + String errorMsg = checkShardsFailure(indexResponse); + if (errorMsg != null) { + listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status())); + return; + } + listener.onResponse(createIndexConfigResponse(indexResponse, copiedConfig)); + } + + @Override + public void onFailure(Exception e) { + logger.warn("Failed to update config", e); + if (e.getMessage() != null && e.getMessage().contains("version conflict")) { + listener.onFailure(new IllegalArgumentException("There was a problem updating the config:[" + id + "]")); + } else { + listener.onFailure(e); + } + } + }); + } + + protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun, ActionListener listener) { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + prepareConfigIndexing(indexingDryRun, listener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + listener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + } + + protected String checkShardsFailure(IndexResponse response) { + StringBuilder failureReasons = new StringBuilder(); + if (response.getShardInfo().getFailed() > 0) { + for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { + failureReasons.append(failure); + } + return failureReasons.toString(); + } + return null; + } + + /** + * Validate config/syntax, and runtime error of config features + * @param id config id + * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector + * @throws IOException when fail to parse feature aggregation + */ + // TODO: move this method to util class so that it can be re-usable for more use cases + // https://github.com/opensearch-project/anomaly-detection/issues/39 + protected void 
validateConfigFeatures(String id, boolean indexingDryRun, ActionListener listener) throws IOException { + if (config != null && (config.getFeatureAttributes() == null || config.getFeatureAttributes().isEmpty())) { + checkConfigNameExists(id, indexingDryRun, listener); + return; + } + // checking configuration/syntax error of detector features + String error = RestHandlerUtils.checkFeaturesSyntax(config, maxFeatures); + if (StringUtils.isNotBlank(error)) { + if (indexingDryRun) { + listener.onFailure(createValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES)); + return; + } + listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); + return; + } + // checking runtime error from feature query + ActionListener>> validateFeatureQueriesListener = ActionListener.wrap(response -> { + checkConfigNameExists(id, indexingDryRun, listener); + }, exception -> { listener.onFailure(createValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES)); }); + MultiResponsesDelegateActionListener>> multiFeatureQueriesResponseListener = + new MultiResponsesDelegateActionListener>>( + validateFeatureQueriesListener, + config.getFeatureAttributes().size(), + getFeatureErrorMsg(config.getName()), + false + ); + + for (Feature feature : config.getFeatureAttributes()) { + SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery()); + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(ssb); + ActionListener searchResponseListener = ActionListener.wrap(response -> { + Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); + if (aggFeatureResult.isPresent()) { + multiFeatureQueriesResponseListener + .onResponse( + new MergeableList>(new ArrayList>(Arrays.asList(aggFeatureResult))) + ); + } else { + String errorMessage = CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + feature.getName(); + logger.error(errorMessage); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + } + }, e -> { + String errorMessage; + if (isExceptionCausedByInvalidQuery(e)) { + errorMessage = CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + feature.getName(); + } else { + errorMessage = CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG + feature.getName(); + } + logger.error(errorMessage, e); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); + }); + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, context, searchResponseListener); + } + } + + protected Integer getMaxSingleStreamConfigs() { + return maxSingleStreamConfigs; + } + + protected Integer getMaxHCConfigs() { + return maxHCConfigs; + } + + protected abstract TimeSeriesException createValidationException(String msg, ValidationIssueType type); + + protected abstract Config parse(XContentParser parser, GetResponse response) throws IOException; + + protected abstract String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs); + + protected abstract String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs); + + protected abstract String getNoDocsInUserIndexErrorMsg(String suppliedIndices); + + 
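+ // Illustrative sketch of how an AD-side subclass could implement the abstract hooks declared
+ // here (hypothetical message text and aspect constant; the real implementations live in the
+ // ad/forecast packages):
+ //
+ // @Override
+ // protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) {
+ //     return new ValidationException(msg, type, ValidationAspect.DETECTOR);
+ // }
+ //
+ // @Override
+ // protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) {
+ //     return "Can't create more than " + maxSingleStreamConfigs + " single-stream detectors.";
+ // }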
protected abstract String getDuplicateConfigErrorMsg(String name); + + protected abstract String getFeatureErrorMsg(String id); + + protected abstract Config copyConfig(User user, Config config); + + protected abstract T createIndexConfigResponse(IndexResponse indexResponse, Config config); + + protected abstract Set<ValidationAspect> getDefaultValidationType(); + + /** + * Validate model + * @param listener listener to return response + */ + protected abstract void validateModel(ActionListener<T> listener); +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java b/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java new file mode 100644 index 000000000..8e676113c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +/** + * Get job to make sure job has been stopped before updating a config. + */ +public class ConfigUpdateConfirmer<IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager<TaskCacheManagerType, TaskTypeEnum, TaskClass>> { + + private final Logger logger = LogManager.getLogger(ConfigUpdateConfirmer.class); + + private final TaskManagerType taskManager; + private final TransportService transportService; + + public ConfigUpdateConfirmer(TaskManagerType taskManager, TransportService transportService) { + this.taskManager = taskManager; + this.transportService = transportService; + } + + /** + * Get the job for a config update/delete. + * If the job exists and is enabled, return an error message; otherwise, execute the function.
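+ * <p>Sketch of the intended call site (mirrors prepareConfigIndexing in
+ * AbstractTimeSeriesActionHandler; not part of this change):
+ * <pre>{@code
+ * handler.confirmJobRunning(clusterService, client, id, listener,
+ *     () -> updateConfig(id, indexingDryRun, listener), xContentRegistry);
+ * }</pre>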
+ * + * @param clusterService OS cluster service + * @param client OS node client + * @param id job identifier + * @param listener Listener to send response + * @param function time series function + * @param xContentRegistry Registry which is used for XContentParser + */ + public void confirmJobRunning( + ClusterService clusterService, + Client client, + String id, + ActionListener listener, + ExecutorFunction function, + NamedXContentRegistry xContentRegistry + ) { + // forecasting and ad share the same job index + if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(id); + client + .get( + request, + ActionListener.wrap(response -> onGetJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { + logger.error("Fail to get job: " + id, exception); + listener.onFailure(exception); + }) + ); + } else { + function.execute(); + } + } + + private void onGetJobResponseForWrite( + GetResponse response, + ActionListener listener, + ExecutorFunction function, + NamedXContentRegistry xContentRegistry + ) { + if (response.isExists()) { + String jobId = response.getId(); + if (jobId != null) { + // check if job is running, if yes, we can't delete the config + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job adJob = Job.parse(parser); + if (adJob.isEnabled()) { + listener.onFailure(new OpenSearchStatusException("Job is running: " + jobId, RestStatus.BAD_REQUEST)); + return; + } + } catch (IOException e) { + String message = "Failed to parse job " + jobId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.BAD_REQUEST)); + } + } + } + function.execute(); + } + + /** + * Confirm if any historical or run once is running. If there is still any left over tasks running, + * listener returns failure complaining task running. Otherwise, listener response returns null + * (indicating no batch running). + * @param configId Config id + * @param tasks tasks to check. + * @param listener to return response or failure. 
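+ * <p>Sketch of a caller (mirrors the update path in AbstractTimeSeriesActionHandler;
+ * proceedWithUpdate is hypothetical):
+ * <pre>{@code
+ * handler.confirmBatchRunning(configId, batchTasks, ActionListener.wrap(
+ *     unused -> proceedWithUpdate(),   // no historical/run-once task is running
+ *     listener::onFailure));           // a task is still running -> 400 BAD_REQUEST
+ * }</pre>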
+ */ + public void confirmBatchRunning(String configId, List tasks, ActionListener listener) { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, tasks, (task) -> { + if (task.isPresent() && !task.get().isDone()) { + // can't update config if there is task running + listener.onFailure(new OpenSearchStatusException("Run once or historical is running", RestStatus.BAD_REQUEST)); + } else { + listener.onResponse(null); + } + }, transportService, false, listener); // false means don't reset task state as inactive/stopped state + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/HistogramAggregationHelper.java b/src/main/java/org/opensearch/timeseries/rest/handler/HistogramAggregationHelper.java new file mode 100644 index 000000000..eeaff0dc1 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/HistogramAggregationHelper.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import java.time.Duration; +import java.time.ZonedDateTime; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.Histogram.Bucket; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +/** + * the class provides helper methods specifically for histogram aggregations + * + */ +public class HistogramAggregationHelper { + protected static final Logger logger = LogManager.getLogger(HistogramAggregationHelper.class); + + protected static final String AGGREGATION = "agg"; + + private Config config; + private final TimeValue requestTimeout; + + public HistogramAggregationHelper(Config config, TimeValue requestTimeout) { + this.config = config; + this.requestTimeout = requestTimeout; + } + + public Histogram checkBucketResultErrors(SearchResponse response) { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with + // the large amounts of changes there). For this reason I'm not throwing a SearchException but instead a validation exception + // which will be converted to validation response. 
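+ // Illustrative call-site handling (sketch, not part of this change): callers typically let
+ // this ValidationException propagate into the validation response, e.g.
+ //   try { Histogram histogram = helper.checkBucketResultErrors(response); ... }
+ //   catch (ValidationException e) { listener.onFailure(e); }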
logger.warn("Unexpected null aggregation."); + throw new ValidationException( + CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, + ValidationIssueType.AGGREGATION, + ValidationAspect.MODEL + ); + } + Histogram aggregate = aggs.get(AGGREGATION); + if (aggregate == null) { + throw new IllegalArgumentException("Failed to find valid aggregation result"); + } + return aggregate; + } + + public AggregationBuilder getBucketAggregation(int intervalInMinutes, LongBounds timeStampBound) { + return AggregationBuilders + .dateHistogram(AGGREGATION) + .field(config.getTimeField()) + .minDocCount(1) + .hardBounds(timeStampBound) + .fixedInterval(DateHistogramInterval.minutes(intervalInMinutes)); + } + + public Long timeConfigToMilliSec(TimeConfiguration timeConfig) { + return Optional.ofNullable((IntervalTimeConfiguration) timeConfig).map(t -> t.toDuration().toMillis()).orElse(0L); + } + + public LongBounds getTimeRangeBounds(long endMillis, long intervalInMillis) { + Long startMillis = endMillis - (getNumberOfSamples(intervalInMillis) * intervalInMillis); + return new LongBounds(startMillis, endMillis); + } + + public int getNumberOfSamples(long intervalInMillis) { + return Math + .max( + (int) (Duration.ofHours(TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS).toMillis() / intervalInMillis), + TimeSeriesSettings.MIN_TRAIN_SAMPLES + ); + } + + /** + * @param histogram buckets returned via date histogram aggregation + * @param intervalInMillis suggested interval to use + * @param config config whose shingle size determines the minimum number of buckets required + * @return the ratio of non-empty buckets to the total number of buckets expected in the sampled range (0 if there are too few buckets) + */ + public double processBucketAggregationResults(Histogram histogram, long intervalInMillis, Config config) { + // In all cases, when the specified end time does not exist, the actual end time is the closest available time after the specified + // end. + // so we only have non-empty buckets + List<? extends Bucket> bucketsInResponse = histogram.getBuckets(); + if (bucketsInResponse.size() >= config.getShingleSize() + TimeSeriesSettings.NUM_MIN_SAMPLES) { + long minTimestampMillis = convertKeyToEpochMillis(bucketsInResponse.get(0).getKey()); + long maxTimestampMillis = convertKeyToEpochMillis(bucketsInResponse.get(bucketsInResponse.size() - 1).getKey()); + double totalBuckets = (maxTimestampMillis - minTimestampMillis) / intervalInMillis; + return histogram.getBuckets().size() / totalBuckets; + } + return 0; + } + + public SearchSourceBuilder getSearchSourceBuilder(QueryBuilder query, AggregationBuilder aggregation) { + return new SearchSourceBuilder().query(query).aggregation(aggregation).size(0).timeout(requestTimeout); + } + + public static long convertKeyToEpochMillis(Object key) { + return key instanceof ZonedDateTime ? ((ZonedDateTime) key).toInstant().toEpochMilli() + : key instanceof Double ? ((Double) key).longValue() + : key instanceof Long ? (Long) key + : -1L; + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java new file mode 100644 index 000000000..92cb6ad65 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java @@ -0,0 +1,594 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.base.Throwables; + +/** + * job REST action handler to process POST/PUT request. 
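+ * <p>Lifecycle sketch (hypothetical wiring; ADIndexJobActionHandler and its forecast
+ * counterpart are the concrete subclasses):
+ * <pre>{@code
+ * jobHandler.startJob(config, transportService, listener); // enable and schedule the realtime job
+ * jobHandler.stopJob(config.getId(), transportService, listener); // disable the job and stop the model
+ * }</pre>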
+ */ +public abstract class IndexJobActionHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder> { + + private final IndexManagementType indexManagement; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + protected final TaskManagerType taskManager; + + private final Logger logger = LogManager.getLogger(IndexJobActionHandler.class); + private final TimeValue requestTimeout; + private final ExecuteResultResponseRecorderType recorder; + private final ActionType> resultAction; + private final AnalysisType analysisType; + private final String stateIndex; + private final ActionType stopConfigAction; + protected final NodeStateManager nodeStateManager; + + /** + * Constructor function. + * + * @param client ES node client that executes actions on the local node + * @param indexManagement index manager + * @param xContentRegistry Registry which is used for XContentParser + * @param taskManager task manager + * @param recorder Utility to record AnomalyResultAction execution result + * @param resultAction result action + * @param analysisType analysis type + * @param stateIndex State index name + * @param stopConfigAction Stop config action + * @param nodeStateManager Node state manager + * @param settings Node settings + * @param timeoutSetting timeout setting + */ + public IndexJobActionHandler( + Client client, + IndexManagementType indexManagement, + NamedXContentRegistry xContentRegistry, + TaskManagerType taskManager, + ExecuteResultResponseRecorderType recorder, + ActionType> resultAction, + AnalysisType analysisType, + String stateIndex, + ActionType stopConfigAction, + NodeStateManager nodeStateManager, + Settings settings, + Setting timeoutSetting + ) { + this.client = client; + this.indexManagement = indexManagement; + this.xContentRegistry = xContentRegistry; + this.taskManager = taskManager; + this.recorder = recorder; + this.resultAction = resultAction; + this.analysisType = analysisType; + this.stateIndex = stateIndex; + this.stopConfigAction = stopConfigAction; + this.nodeStateManager = nodeStateManager; + this.requestTimeout = timeoutSetting.get(settings); + } + + /** + * Start job. + * 1. If job doesn't exist, create new job. + * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. + * @param config config accessor + * @param listener Listener to send responses + */ + public void startJob(Config config, TransportService transportService, ActionListener listener) { + // this start listener is created & injected throughout the job handler so that whenever the job response is received, + // there's the extra step of trying to index results and update detector state with a 60s delay. 
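+ // Sketch of the wrapped flow (descriptive comment, not part of this change): on each job
+ // response, the listener below issues one more result request covering the last interval,
+ // [now - interval, now], then records the outcome via the recorder:
+ //   recorder.indexResult(executionStartTime, executionEndTime, response, config) on success,
+ //   recorder.indexResultException(...) on failure.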
+ ActionListener startListener = ActionListener.wrap(r -> { + try { + Instant executionEndTime = Instant.now(); + IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) config.getInterval(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + ResultRequest getRequest = createResultRequest( + config.getId(), + executionStartTime.toEpochMilli(), + executionEndTime.toEpochMilli() + ); + client + .execute( + resultAction, + getRequest, + ActionListener + .wrap(response -> recorder.indexResult(executionStartTime, executionEndTime, response, config), exception -> { + + recorder + .indexResultException( + executionStartTime, + executionEndTime, + Throwables.getStackTraceAsString(exception), + null, + config + ); + }) + ); + } catch (Exception ex) { + listener.onFailure(ex); + return; + } + listener.onResponse(r); + + }, listener::onFailure); + if (!indexManagement.doesJobIndexExist()) { + indexManagement.initJobIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + createJob(config, transportService, startListener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + startListener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + " with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + }, exception -> startListener.onFailure(exception))); + } else { + createJob(config, transportService, startListener); + } + } + + private void createJob(Config config, TransportService transportService, ActionListener listener) { + try { + IntervalTimeConfiguration interval = (IntervalTimeConfiguration) config.getInterval(); + Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); + Duration duration = Duration.of(interval.getInterval(), interval.getUnit()); + + Job job = new Job( + config.getId(), + schedule, + config.getWindowDelay(), + true, + Instant.now(), + null, + Instant.now(), + duration.getSeconds(), + config.getUser(), + config.getCustomResultIndex(), + analysisType + ); + + getJobForWrite(config, job, transportService, listener); + } catch (Exception e) { + String message = "Failed to parse job " + config.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + private void getJobForWrite(Config config, Job job, TransportService transportService, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(config.getId()); + + client + .get( + getRequest, + ActionListener + .wrap( + response -> onGetJobForWrite(response, config, job, transportService, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetJobForWrite( + GetResponse response, + Config config, + Job job, + TransportService transportService, + ActionListener listener + ) throws IOException { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job currentAdJob = Job.parse(parser); + if (currentAdJob.isEnabled()) { + listener + .onFailure( + new OpenSearchStatusException("Anomaly detector job is already running: " + config.getId(), RestStatus.OK) + ); + return; + } else { + Job newJob = 
new Job( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + job.isEnabled(), + Instant.now(), + currentAdJob.getDisabledTime(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex(), + job.getAnalysisType() + ); + // Get latest realtime task and check its state before index job. Will reset running realtime task + // as STOPPED first if job disabled, then start new job and create new realtime task. + startConfig( + config, + null, + job.getUser(), + transportService, + ActionListener.wrap(r -> { indexJob(newJob, null, listener); }, e -> { + // Have logged error message in ADTaskManager#startDetector + listener.onFailure(e); + }) + ); + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector job " + job.getName(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + startConfig( + config, + null, + job.getUser(), + transportService, + ActionListener.wrap(r -> { indexJob(job, null, listener); }, e -> listener.onFailure(e)) + ); + } + } + + /** + * Start config. + * For historical analysis, this method will be called on coordinating node. + * For realtime task, we won't know AD job coordinating node until AD job starts. So + * this method will be called on vanilla node. + * + * Will init task index if not exist and write new AD task to index. If task index + * exists, will check if there is task running. If no running task, reset old task + * as not latest and clean old tasks which exceeds max old task doc limitation. + * Then find out node with least load and dispatch task to that node(worker node). + * + * @param config anomaly detector + * @param dateRange detection date range + * @param user user + * @param transportService transport service + * @param listener action listener + */ + public void startConfig( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener + ) { + try { + if (indexManagement.doesStateIndexExist()) { + // If state index exist, check if latest AD task is running + taskManager.getAndExecuteOnLatestConfigLevelTask(config, dateRange, false, user, transportService, listener); + } else { + // If state index doesn't exist, create index and execute detector. 
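+ // Outcome sketch (descriptive comment, not part of this change): initStateIndex either
+ // (a) acknowledges creation -> mark old tasks not-latest and create the new task,
+ // (b) loses a concurrent-create race (ResourceAlreadyExistsException) -> proceed the same way,
+ // or (c) fails -> surface the error to the listener. Only (a) and (b) start the analysis.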
+
+    /**
+     * Start config.
+     * For historical analysis, this method will be called on the coordinating node.
+     * For a realtime task, the job's coordinating node is unknown until the job starts, so
+     * this method is called on whichever node receives the request.
+     *
+     * Initializes the task index if it does not exist and writes a new AD task to it. If the task
+     * index exists, checks whether a task is already running. If no task is running, resets old
+     * tasks as not latest, cleans up old tasks exceeding the max task doc limit, then finds the
+     * node with the least load and dispatches the task to it (the worker node).
+     *
+     * @param config anomaly detector
+     * @param dateRange detection date range
+     * @param user user
+     * @param transportService transport service
+     * @param listener action listener
+     */
+    public void startConfig(
+        Config config,
+        DateRange dateRange,
+        User user,
+        TransportService transportService,
+        ActionListener<JobResponse> listener
+    ) {
+        try {
+            if (indexManagement.doesStateIndexExist()) {
+                // If the state index exists, check whether the latest AD task is running
+                taskManager.getAndExecuteOnLatestConfigLevelTask(config, dateRange, false, user, transportService, listener);
+            } else {
+                // If the state index doesn't exist, create it and then execute the detector.
+                indexManagement.initStateIndex(ActionListener.wrap(r -> {
+                    if (r.isAcknowledged()) {
+                        logger.info("Created {} with mappings.", stateIndex);
+                        taskManager.updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, false, user, TaskState.CREATED, listener);
+                    } else {
+                        String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, stateIndex);
+                        logger.warn(error);
+                        listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR));
+                    }
+                }, e -> {
+                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
+                        taskManager.updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, false, user, TaskState.CREATED, listener);
+                    } else {
+                        logger.error("Failed to init anomaly detection state index", e);
+                        listener.onFailure(e);
+                    }
+                }));
+            }
+        } catch (Exception e) {
+            logger.error("Failed to start detector " + config.getId(), e);
+            listener.onFailure(e);
+        }
+    }
+
+    private void indexJob(Job job, ExecutorFunction function, ActionListener<JobResponse> listener) throws IOException {
+        IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX)
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE))
+            .timeout(requestTimeout)
+            .id(job.getName());
+        client
+            .index(
+                indexRequest,
+                ActionListener
+                    .wrap(
+                        response -> onIndexAnomalyDetectorJobResponse(response, function, listener),
+                        exception -> listener.onFailure(exception)
+                    )
+            );
+    }
+
+    private void onIndexAnomalyDetectorJobResponse(
+        IndexResponse response,
+        ExecutorFunction function,
+        ActionListener<JobResponse> listener
+    ) {
+        if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) {
+            String errorMsg = ExceptionUtil.getShardsFailure(response);
+            listener.onFailure(new OpenSearchStatusException(errorMsg, response.status()));
+            return;
+        }
+        if (function != null) {
+            function.execute();
+        } else {
+            JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId());
+            listener.onResponse(anomalyDetectorJobResponse);
+        }
+    }
+
+    /**
+     * Stop a config's job.
+     * 1. If the job does not exist, respond with the config id; there is nothing to stop.
+     * 2. If the job exists: a) if the job is already disabled, only reset the latest realtime task;
+     *    b) if the job is enabled, disable it and then stop the model.
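+     * For example, stopping a job that is already disabled skips the job update and only resets the
+     * latest realtime task to STOPPED, while stopping an enabled job first re-indexes the job with
+     * enabled=false and then executes the stop-config action to delete the model.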
+     *
+     * @param configId config identifier
+     * @param transportService transport service
+     * @param listener listener used to send responses
+     */
+    public void stopJob(String configId, TransportService transportService, ActionListener<JobResponse> listener) {
+        GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(configId);
+
+        client.get(getRequest, ActionListener.wrap(response -> {
+            if (response.isExists()) {
+                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
+                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+                    Job job = Job.parse(parser);
+                    if (!job.isEnabled()) {
+                        taskManager.stopLatestRealtimeTask(configId, TaskState.STOPPED, null, transportService, listener);
+                    } else {
+                        Job newJob = new Job(
+                            job.getName(),
+                            job.getSchedule(),
+                            job.getWindowDelay(),
+                            false, // disable job
+                            job.getEnabledTime(),
+                            Instant.now(),
+                            Instant.now(),
+                            job.getLockDurationSeconds(),
+                            job.getUser(),
+                            job.getCustomResultIndex(),
+                            job.getAnalysisType()
+                        );
+                        indexJob(
+                            newJob,
+                            () -> client
+                                .execute(
+                                    stopConfigAction,
+                                    new StopConfigRequest(configId),
+                                    stopConfigListener(configId, transportService, listener)
+                                ),
+                            listener
+                        );
+                    }
+                } catch (IOException e) {
+                    String message = "Failed to parse job " + configId;
+                    logger.error(message, e);
+                    listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
+                }
+            } else {
+                listener.onResponse(new JobResponse(configId));
+            }
+        }, exception -> {
+            if (exception instanceof IndexNotFoundException) {
+                listener.onResponse(new JobResponse(configId));
+            } else {
+                listener.onFailure(exception);
+            }
+        }));
+    }
+
+    private ActionListener<StopConfigResponse> stopConfigListener(
+        String configId,
+        TransportService transportService,
+        ActionListener<JobResponse> listener
+    ) {
+        return new ActionListener<StopConfigResponse>() {
+            @Override
+            public void onResponse(StopConfigResponse stopDetectorResponse) {
+                if (stopDetectorResponse.success()) {
+                    logger.info("Model deleted successfully for config {}", configId);
+                    // e.g., StopDetectorTransportAction will send out DeleteModelAction, which clears all realtime caches.
+                    // Pass a null transport service to "stopLatestRealtimeTask" so the coordinating node cache is not re-cleared.
+                    taskManager.stopLatestRealtimeTask(configId, TaskState.STOPPED, null, null, listener);
+                } else {
+                    logger.error("Failed to delete model for config {}", configId);
+                    // Clearing all realtime caches failed, so try to re-clear the coordinating node cache.
+                    taskManager
+                        .stopLatestRealtimeTask(
+                            configId,
+                            TaskState.FAILED,
+                            new OpenSearchStatusException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR),
+                            transportService,
+                            listener
+                        );
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                logger.error("Failed to delete model for config " + configId, e);
+                // Clearing all realtime caches failed, so try to re-clear the coordinating node cache.
+                taskManager
+                    .stopLatestRealtimeTask(
+                        configId,
+                        TaskState.FAILED,
+                        new OpenSearchStatusException("Failed to execute stop config action", RestStatus.INTERNAL_SERVER_ERROR),
+                        transportService,
+                        listener
+                    );
+            }
+        };
+    }
+
+    /**
+     * Start config. Creates a scheduled job for realtime analysis
+     * and starts a task for historical analysis/run-once.
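+     * For example, a call with a null dateRange starts a realtime scheduled job, while a non-null
+     * dateRange kicks off a historical analysis task (see startRealtimeOrHistoricalAnalysis).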
+ * + * @param configId config id + * @param dateRange historical analysis date range + * @param user user + * @param transportService transport service + * @param context thread context + * @param listener action listener + */ + public void startConfig( + String configId, + DateRange dateRange, + User user, + TransportService transportService, + ThreadContext.StoredContext context, + ActionListener listener + ) { + // upgrade index mapping + indexManagement.update(); + + nodeStateManager.getConfig(configId, analysisType, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + + // Validate if config is ready to start. Will return null if ready to start. + String errorMessage = validateConfig(config.get()); + if (errorMessage != null) { + listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + return; + } + String resultIndex = config.get().getCustomResultIndex(); + if (resultIndex == null) { + startRealtimeOrHistoricalAnalysis(dateRange, user, transportService, listener, config); + return; + } + context.restore(); + indexManagement + .initCustomResultIndexAndExecute( + resultIndex, + () -> startRealtimeOrHistoricalAnalysis(dateRange, user, transportService, listener, config), + listener + ); + + }, listener); + } + + private String validateConfig(Config detector) { + String error = null; + if (detector.getFeatureAttributes().size() == 0) { + error = "Can't start job as no features configured"; + } else if (detector.getEnabledFeatureIds().size() == 0) { + error = "Can't start job as no enabled features configured"; + } + return error; + } + + private void startRealtimeOrHistoricalAnalysis( + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener, + Optional config + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (dateRange == null) { + // start realtime job + startJob(config.get(), transportService, listener); + } else { + // start historical analysis task + taskManager.startHistorical(config.get(), dateRange, user, transportService, listener); + } + } catch (Exception e) { + logger.error("Failed to stash context", e); + listener.onFailure(e); + } + } + + protected abstract ResultRequest createResultRequest(String configID, long start, long end); + + protected abstract List getBatchConfigTaskTypes(); + + public abstract void stopConfig( + String configId, + boolean historical, + User user, + TransportService transportService, + ActionListener listener + ); +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java new file mode 100644 index 000000000..e3c8a403e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IntervalCalculation.java @@ -0,0 +1,431 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import java.io.IOException; +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import 
org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class IntervalCalculation { + private final Logger logger = LogManager.getLogger(IntervalCalculation.class); + + private final Config config; + private final TimeValue requestTimeout; + private final HistogramAggregationHelper histogramAggHelper; + private final Client client; + private final SecurityClientUtil clientUtil; + private final User user; + private final AnalysisType context; + private final Clock clock; + private final FullBucketRatePredicate acceptanceCriteria; + + public IntervalCalculation( + Config config, + TimeValue requestTimeout, + Client client, + SecurityClientUtil clientUtil, + User user, + AnalysisType context, + Clock clock + ) { + this.config = config; + this.requestTimeout = requestTimeout; + this.histogramAggHelper = new HistogramAggregationHelper(config, requestTimeout); + this.client = client; + this.clientUtil = clientUtil; + this.user = user; + this.context = context; + this.clock = clock; + this.acceptanceCriteria = new FullBucketRatePredicate(); + + } + + public void findInterval(long latestTime, Map topEntity, ActionListener listener) { + ActionListener> minimumIntervalListener = ActionListener.wrap(minIntervalAndValidity -> { + if (minIntervalAndValidity.getRight()) { + // the minimum interval is also the interval passing acceptance criteria and we can return immediately + listener.onResponse(minIntervalAndValidity.getLeft()); + } else if (minIntervalAndValidity.getLeft() == null) { + // the minimum interval is too large + listener.onResponse(null); + } else { + // starting exploring larger interval + getBucketAggregates(latestTime, topEntity, minIntervalAndValidity.getLeft(), listener); + } + }, listener::onFailure); + // we use 1 minute = 60000 milliseconds to find minimum interval + LongBounds longBounds = histogramAggHelper.getTimeRangeBounds(latestTime, 60000); + findMinimumInterval(topEntity, longBounds, minimumIntervalListener); + } + + private void getBucketAggregates( + long latestTime, + Map topEntity, + IntervalTimeConfiguration minimumInterval, + ActionListener listener + ) throws IOException { + + try { + int newIntervalInMinutes = increaseAndGetNewInterval(minimumInterval); + LongBounds timeStampBounds = histogramAggHelper.getTimeRangeBounds(latestTime, newIntervalInMinutes); + SearchRequest searchRequest = composeIntervalQuery(topEntity, newIntervalInMinutes, 
timeStampBounds);
+            ActionListener<IntervalTimeConfiguration> intervalListener = ActionListener
+                .wrap(interval -> listener.onResponse(interval), exception -> {
+                    logger.error("Failed to get interval recommendation", exception);
+                    listener.onFailure(exception);
+                });
+            final ActionListener<SearchResponse> searchResponseListener = new IntervalRecommendationListener(
+                intervalListener,
+                searchRequest.source(),
+                (IntervalTimeConfiguration) config.getInterval(),
+                clock.millis() + TimeSeriesSettings.TOP_VALIDATE_TIMEOUT_IN_MILLIS,
+                latestTime,
+                timeStampBounds
+            );
+            // using the original context in listener as user roles have no permissions for internal operations like fetching a
+            // checkpoint
+            clientUtil
+                .asyncRequestWithInjectedSecurity(
+                    searchRequest,
+                    client::search,
+                    user,
+                    client,
+                    context,
+                    searchResponseListener
+                );
+        } catch (ValidationException ex) {
+            listener.onFailure(ex);
+        }
+    }
+
+    /**
+     * Increase the previously tried interval by INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER and round up.
+     *
+     * @param oldInterval the interval tried in the previous iteration
+     * @return new interval in minutes
+     */
+    private int increaseAndGetNewInterval(IntervalTimeConfiguration oldInterval) {
+        return (int) Math
+            .ceil(
+                IntervalTimeConfiguration.getIntervalInMinute(oldInterval)
+                    * TimeSeriesSettings.INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER
+            );
+    }
+
+    /**
+     * ActionListener class that executes multiple bucket aggregations one after another.
+     * Bucket aggregations with different interval lengths are executed one by one to check whether the data is dense enough.
+     * The next query is executed only if the previous one found the data too sparse.
+     */
+    class IntervalRecommendationListener implements ActionListener<SearchResponse> {
+        private final ActionListener<IntervalTimeConfiguration> intervalListener;
+        SearchSourceBuilder searchSourceBuilder;
+        IntervalTimeConfiguration currentIntervalToTry;
+        private final long expirationEpochMs;
+        private final long latestTime;
+        private LongBounds currentTimeStampBounds;
+
+        IntervalRecommendationListener(
+            ActionListener<IntervalTimeConfiguration> intervalListener,
+            SearchSourceBuilder searchSourceBuilder,
+            IntervalTimeConfiguration currentIntervalToTry,
+            long expirationEpochMs,
+            long latestTime,
+            LongBounds timeStampBounds
+        ) {
+            this.intervalListener = intervalListener;
+            this.searchSourceBuilder = searchSourceBuilder;
+            this.currentIntervalToTry = currentIntervalToTry;
+            this.expirationEpochMs = expirationEpochMs;
+            this.latestTime = latestTime;
+            this.currentTimeStampBounds = timeStampBounds;
+        }
+
+        @Override
+        public void onResponse(SearchResponse response) {
+            try {
+                Histogram aggregate = null;
+                try {
+                    aggregate = histogramAggHelper.checkBucketResultErrors(response);
+                } catch (ValidationException e) {
+                    intervalListener.onFailure(e);
+                    // return here so the listener is notified exactly once; falling through would
+                    // also trigger the onResponse(null) branch below
+                    return;
+                }
+
+                if (aggregate == null) {
+                    intervalListener.onResponse(null);
+                    return;
+                }
+
+                int newIntervalMinute = increaseAndGetNewInterval(currentIntervalToTry);
+                double fullBucketRate = histogramAggHelper.processBucketAggregationResults(aggregate, newIntervalMinute * 60000, config);
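+                // Worked example (assuming a hypothetical 5-minute starting interval): with
+                // INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2, successive candidates are
+                // ceil(5 * 1.2) = 6, ceil(6 * 1.2) = 8 (7.2 rounded up), ceil(8 * 1.2) = 10, and so
+                // on, until either the bucket rate passes the threshold or the candidate reaches
+                // MAX_INTERVAL_REC_LENGTH_IN_MINUTES (60).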
+                // If the rate is above the success minimum, return the interval suggestion.
+                if (fullBucketRate > TimeSeriesSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) {
+                    intervalListener.onResponse(this.currentIntervalToTry);
+                } else if (expirationEpochMs < clock.millis()) {
+                    intervalListener
+                        .onFailure(
+                            new ValidationException(
+                                CommonMessages.TIMEOUT_ON_INTERVAL_REC,
+                                ValidationIssueType.TIMEOUT,
+                                ValidationAspect.MODEL
+                            )
+                        );
+                    logger.info(CommonMessages.TIMEOUT_ON_INTERVAL_REC);
+                    // keep trying larger intervals while the new interval is still below the max
+                } else if (newIntervalMinute < TimeSeriesSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES) {
+                    searchWithDifferentInterval(newIntervalMinute);
+                } else {
+                    // newIntervalMinute >= MAX_INTERVAL_REC_LENGTH_IN_MINUTES: the max recommendation
+                    // length was reached without passing the density check, so no interval is suggested
+                    intervalListener.onResponse(null);
+                }
+            } catch (Exception e) {
+                onFailure(e);
+            }
+        }
+
+        private void searchWithDifferentInterval(int newIntervalMinuteValue) {
+            this.currentIntervalToTry = new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES);
+            this.currentTimeStampBounds = histogramAggHelper.getTimeRangeBounds(latestTime, newIntervalMinuteValue);
+            // Search again using the updated interval
+            SearchSourceBuilder updatedSearchSourceBuilder = histogramAggHelper
+                .getSearchSourceBuilder(
+                    searchSourceBuilder.query(),
+                    histogramAggHelper.getBucketAggregation(newIntervalMinuteValue, currentTimeStampBounds)
+                );
+            // using the original context in listener as user roles have no permissions for internal operations like fetching a
+            // checkpoint
+            clientUtil
+                .asyncRequestWithInjectedSecurity(
+                    new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(updatedSearchSourceBuilder),
+                    client::search,
+                    user,
+                    client,
+                    context,
+                    this
+                );
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            logger.error("Failed to recommend new interval", e);
+            intervalListener
+                .onFailure(
+                    new ValidationException(
+                        CommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY,
+                        ValidationIssueType.AGGREGATION,
+                        ValidationAspect.MODEL
+                    )
+                );
+        }
+    }
+
+    /**
+     * This method calculates the median timestamp difference and uses it as the minimum interval.
+     *
+     * Using the median timestamp difference as a minimum sampling interval is a heuristic approach
+     * that can be beneficial in specific contexts, especially when dealing with irregularly spaced data.
+     *
+     * Advantages:
+     * 1. Robustness: The median is less sensitive to outliers compared to the mean. This makes it a
+     * more stable metric in the presence of irregular data points or anomalies.
+     * 2. Reflects Typical Intervals: The median provides a measure of the "typical" interval between
+     * data points, which can be useful when there are varying intervals.
+     *
+     * Disadvantages:
+     * 1. Not Standard in Signal Processing: Traditional signal processing often relies on fixed
+     * sampling rates determined by the Nyquist-Shannon sampling theorem. The median-based approach
+     * is more of a data-driven heuristic.
+     * 2. May Not Capture All Features: Depending on the nature of the data, using the median interval
+     * might miss some rapid events or features in the data.
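+     * For example (hypothetical timestamps), data points at 00:00, 00:01, 00:02, and 00:10 give
+     * differences of 1, 1, and 8 minutes; the median difference is 1 minute, while the mean
+     * (10/3, roughly 3.3 minutes) would be inflated by the single large gap.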
+ * + * In summary, while not a standard practice, using the median timestamp difference as a sampling + * interval can be a practical approach in scenarios where data arrival is irregular and there's + * a need to balance between capturing data features and avoiding over-sampling. + * + * @param topEntity top entity to use + * @param timeStampBounds Used to determine start and end date range to search for data + * @param listener returns minimum interval and whether the interval passes data density test + */ + private void findMinimumInterval( + Map topEntity, + LongBounds timeStampBounds, + ActionListener> listener + ) { + try { + SearchRequest searchRequest = composeIntervalQuery(topEntity, 1, timeStampBounds); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + // fail to find the minimum interval. Return one minute. + logger.warn("Fail to get aggregated result"); + listener.onResponse(Pair.of(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), Boolean.FALSE)); + return; + } + // In all cases, when the specified end time does not exist, the actual end time is the closest available time after the + // specified end. + // so we only have non-empty buckets + // in the original order, buckets are sorted in the ascending order of timestamps. + // Since the stream processing preserves the order of elements, we don't need to sort timestamps again. + List timestamps = aggregate + .getBuckets() + .stream() + .map(entry -> HistogramAggregationHelper.convertKeyToEpochMillis(entry.getKey())) + .collect(Collectors.toList()); + + if (timestamps.isEmpty()) { + logger.warn("empty data, return one minute by default"); + listener.onResponse(Pair.of(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), Boolean.FALSE)); + return; + } + + double medianDifference = calculateMedianDifference(timestamps); + long minimumMinutes = millisecondsToCeilMinutes(((Double) medianDifference).longValue()); + if (minimumMinutes > TimeSeriesSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES) { + logger.warn("The minimum interval is too large: {}", minimumMinutes); + listener.onResponse(Pair.of(null, false)); + return; + } + listener + .onResponse( + Pair + .of( + new IntervalTimeConfiguration(minimumMinutes, ChronoUnit.MINUTES), + acceptanceCriteria.test(aggregate, minimumMinutes) + ) + ); + }, listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private static double calculateMedianDifference(List timestamps) { + List differences = new ArrayList<>(); + + for (int i = 1; i < timestamps.size(); i++) { + differences.add(timestamps.get(i) - timestamps.get(i - 1)); + } + + Collections.sort(differences); + + int middle = differences.size() / 2; + if (differences.size() % 2 == 0) { + // If even number of differences, return the average of the two middle values + return (differences.get(middle - 1) + differences.get(middle)) / 2.0; + } else { + // If odd number of differences, return the middle value + return differences.get(middle); + } + } + + /** + * Convert a duration in milliseconds to the nearest 
minute value that is greater than + * or equal to the given duration. + * + * For example, a duration of 123456 milliseconds is slightly more than 2 minutes. + * So, it gets rounded up and the method returns 3. + * + * @param milliseconds The duration in milliseconds. + * @return The rounded up value in minutes. + */ + private static long millisecondsToCeilMinutes(long milliseconds) { + // Since there are 60,000 milliseconds in a minute, we divide by 60,000 to get + // the number of complete minutes. We add 59,999 before division to ensure + // that any duration that exceeds a whole minute but is less than the next + // whole minute is rounded up to the next minute. + return (milliseconds + 59999) / 60000; + } + + private SearchRequest composeIntervalQuery(Map topEntity, int intervalInMinutes, LongBounds timeStampBounds) { + AggregationBuilder aggregation = histogramAggHelper.getBucketAggregation(intervalInMinutes, timeStampBounds); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + if (config.isHighCardinality()) { + if (topEntity.isEmpty()) { + throw new ValidationException( + CommonMessages.CATEGORY_FIELD_TOO_SPARSE, + ValidationIssueType.CATEGORY, + ValidationAspect.MODEL + ); + } + for (Map.Entry entry : topEntity.entrySet()) { + query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); + } + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(query) + .aggregation(aggregation) + .size(0) + .timeout(requestTimeout); + return new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + } + + interface HistogramPredicate { + boolean test(Histogram histogram, long minimumMinutes); + } + + class FullBucketRatePredicate implements HistogramPredicate { + + @Override + public boolean test(Histogram histogram, long minimumMinutes) { + double fullBucketRate = histogramAggHelper.processBucketAggregationResults(histogram, minimumMinutes * 60000, config); + // If rate is above success minimum then return true. 
+ return fullBucketRate > TimeSeriesSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE; + } + + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/LatestTimeRetriever.java b/src/main/java/org/opensearch/timeseries/rest/handler/LatestTimeRetriever.java new file mode 100644 index 000000000..5d0393842 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/LatestTimeRetriever.java @@ -0,0 +1,186 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import java.time.Instant; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.SecurityClientUtil; + +public class LatestTimeRetriever { + public static final Logger logger = LogManager.getLogger(LatestTimeRetriever.class); + + protected static final String AGG_NAME_TOP = "top_agg"; + + private final Config config; + // private final ActionListener> listener; + private final HistogramAggregationHelper histogramAggHelper; + private final SecurityClientUtil clientUtil; + private final Client client; + private final User user; + private final AnalysisType context; + private final SearchFeatureDao searchFeatureDao; + + public LatestTimeRetriever( + Config config, + TimeValue requestTimeout, + SecurityClientUtil clientUtil, + Client client, + User user, + AnalysisType context, + SearchFeatureDao searchFeatureDao + ) { + this.config = config; + this.histogramAggHelper = new HistogramAggregationHelper(config, requestTimeout); + this.clientUtil = clientUtil; + this.client = client; + this.user = user; + this.context = context; + this.searchFeatureDao = searchFeatureDao; + } + + /** + * Need to first check if HC analysis or not before retrieving latest date time. 
+ * If the config is HC then we will find the top entity and treat as single stream for + * validation purposes + * @param listener to return latest time and entity attributes if the config is HC + */ + public void checkIfHC(ActionListener, Map>> listener) { + ActionListener> topEntityListener = ActionListener + .wrap( + topEntity -> searchFeatureDao + .getLatestDataTime( + config, + Optional.of(Entity.createEntityByReordering(topEntity)), + context, + ActionListener.wrap(latestTime -> listener.onResponse(Pair.of(latestTime, topEntity)), listener::onFailure) + ), + exception -> { + listener.onFailure(exception); + logger.error("Failed to get top entity for categorical field", exception); + } + ); + if (config.isHighCardinality()) { + getTopEntity(topEntityListener); + } else { + topEntityListener.onResponse(Collections.emptyMap()); + } + } + + // For single category HCs, this method uses bucket aggregation and sort to get the category field + // that have the highest document count in order to use that top entity for further validation + // For multi-category HCs we use a composite aggregation to find the top fields for the entity + // with the highest doc count. + public void getTopEntity(ActionListener> topEntityListener) { + // Look at data back to the lower bound given the max interval we recommend or one given + long maxIntervalInMinutes = Math.max(TimeSeriesSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES, config.getIntervalInMinutes()); + LongBounds timeRangeBounds = histogramAggHelper.getTimeRangeBounds(Instant.now().toEpochMilli(), maxIntervalInMinutes * 60000); + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) + .from(timeRangeBounds.getMin()) + .to(timeRangeBounds.getMax()); + AggregationBuilder bucketAggs; + Map topKeys = new HashMap<>(); + if (config.getCategoryFields().size() == 1) { + bucketAggs = AggregationBuilders.terms(AGG_NAME_TOP).field(config.getCategoryFields().get(0)).order(BucketOrder.count(true)); + } else { + bucketAggs = AggregationBuilders + .composite( + AGG_NAME_TOP, + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(1000) + .subAggregation( + PipelineAggregatorBuilders + .bucketSort("bucketSort", Collections.singletonList(new FieldSortBuilder("_count").order(SortOrder.DESC))) + .size(1) + ); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(rangeQuery) + .aggregation(bucketAggs) + .trackTotalHits(false) + .size(0); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + topEntityListener.onResponse(Collections.emptyMap()); + return; + } + if (config.getCategoryFields().size() == 1) { + Terms entities = aggs.get(AGG_NAME_TOP); + Object key = entities + .getBuckets() + .stream() + .max(Comparator.comparingInt(entry -> (int) entry.getDocCount())) + .map(MultiBucketsAggregation.Bucket::getKeyAsString) + .orElse(null); + topKeys.put(config.getCategoryFields().get(0), key); + } else { + CompositeAggregation compositeAgg = aggs.get(AGG_NAME_TOP); + topKeys + .putAll( + compositeAgg + .getBuckets() + .stream() + .flatMap(bucket -> bucket.getKey().entrySet().stream()) // this would create a flattened stream of map entries + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())) + ); + } + for 
(Map.Entry entry : topKeys.entrySet()) { + if (entry.getValue() == null) { + topEntityListener.onResponse(Collections.emptyMap()); + return; + } + } + topEntityListener.onResponse(topKeys); + }, topEntityListener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/ModelValidationActionHandler.java new file mode 100644 index 000000000..7e49f3ead --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/ModelValidationActionHandler.java @@ -0,0 +1,482 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CONFIG_BUCKET_MINIMUM_SUCCESS_RATE; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.ValidateConfigResponse; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; + +/** + *

This class executes all validation checks that are not blocking on the 'model' level.
+ * This mostly involves checking whether the data is generally dense enough to complete model training,
+ * which is based on whether enough buckets in the last x intervals have at least 1 document present.
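+ * As a hypothetical example: with the default INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE of 0.75, a
+ * window in which 80 of the last 100 sampled intervals contain at least one document yields a
+ * bucket rate of 0.8 and passes, while 70 non-empty intervals (a rate of 0.7) would fail.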

+ *

Initially, different bucket aggregations are executed with every configuration applied and with
+ * varying intervals in order to find the best interval for the data. If no interval is found with all
+ * configurations applied, then each configuration is tested sequentially for sparsity.
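+ * For example, if the raw data is dense enough but adding the filter query drops the bucket rate
+ * below the threshold, the filter query is reported as the root cause (FILTER_QUERY_TOO_SPARSE);
+ * category fields and feature queries are then checked in the same incremental fashion.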

+ */ +// TODO: Add more UT and IT +public class ModelValidationActionHandler { + + protected final Config config; + protected final ClusterService clusterService; + protected final Logger logger = LogManager.getLogger(ModelValidationActionHandler.class); + protected final TimeValue requestTimeout; + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final NamedXContentRegistry xContentRegistry; + protected final ActionListener listener; + protected final Clock clock; + protected final String validationType; + protected final Settings settings; + protected final User user; + protected final AnalysisType context; + private final HistogramAggregationHelper histogramAggHelper; + private final IntervalCalculation intervalCalculation; + // time range bounds to verify configured interval makes sense or not + private LongBounds timeRangeToSearchForConfiguredInterval; + private final LatestTimeRetriever latestTimeRetriever; + + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client OS node client that executes actions on the local node + * @param clientUtil client util + * @param listener OS channel used to construct bytes / builder based outputs, and send responses + * @param config config instance + * @param requestTimeout request time out configuration + * @param xContentRegistry Registry which is used for XContentParser + * @param searchFeatureDao Search feature DAO + * @param validationType Specified type for validation + * @param clock clock object to know when to timeout + * @param settings Node settings + * @param user User info + * @param context Analysis type + */ + public ModelValidationActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + Config config, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings, + User user, + AnalysisType context + ) { + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + this.listener = listener; + this.config = config; + this.requestTimeout = requestTimeout; + this.xContentRegistry = xContentRegistry; + this.validationType = validationType; + this.clock = clock; + this.settings = settings; + this.user = user; + this.context = context; + this.histogramAggHelper = new HistogramAggregationHelper(config, requestTimeout); + this.intervalCalculation = new IntervalCalculation(config, requestTimeout, client, clientUtil, user, context, clock); + // calculate the bounds in a lazy manner + this.timeRangeToSearchForConfiguredInterval = null; + this.latestTimeRetriever = new LatestTimeRetriever(config, requestTimeout, clientUtil, client, user, context, searchFeatureDao); + } + + public void start() { + ActionListener, Map>> latestTimeListener = ActionListener + .wrap( + latestEntityAttributes -> getSampleRangesForValidationChecks( + latestEntityAttributes.getLeft(), + config, + listener, + latestEntityAttributes.getRight() + ), + exception -> { + listener.onFailure(exception); + logger.error("Failed to create search request for last data point", exception); + } + ); + latestTimeRetriever.checkIfHC(latestTimeListener); + } + + private void getSampleRangesForValidationChecks( + Optional latestTime, + Config config, + ActionListener listener, + Map topEntity + ) { + if (!latestTime.isPresent() || latestTime.get() <= 0) { + listener + .onFailure( + new 
ValidationException(
+                    CommonMessages.TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA,
+                    ValidationIssueType.TIMEFIELD_FIELD,
+                    ValidationAspect.MODEL
+                )
+            );
+            return;
+        }
+        long timeRangeEnd = Math.min(Instant.now().toEpochMilli(), latestTime.get());
+        intervalCalculation
+            .findInterval(
+                timeRangeEnd,
+                topEntity,
+                ActionListener.wrap(interval -> processIntervalRecommendation(interval, latestTime.get()), listener::onFailure)
+            );
+    }
+
+    private void processIntervalRecommendation(IntervalTimeConfiguration interval, long latestTime) {
+        // A null interval suggestion means no interval could be found with all the configurations applied.
+        // The next step is to check density with just the raw data, and then add each configuration
+        // one at a time to find the root cause of the low density.
+        if (interval == null) {
+            checkRawDataSparsity(latestTime);
+        } else {
+            if (((IntervalTimeConfiguration) config.getInterval()).gte(interval)) {
+                logger.info("Using the current interval as the data is dense enough");
+                // If everything else succeeded, check whether there is a window delay recommendation and,
+                // if so, send it as an exception
+                if (Instant.now().toEpochMilli() - latestTime > histogramAggHelper.timeConfigToMilliSec(config.getWindowDelay())) {
+                    sendWindowDelayRec(latestTime);
+                    return;
+                }
+                // The rate of buckets with at least 1 doc at the given interval is above the success rate
+                listener.onResponse(null);
+                return;
+            }
+            // return response with the interval recommendation
+            listener
+                .onFailure(
+                    new ValidationException(
+                        CommonMessages.INTERVAL_REC + interval.getInterval(),
+                        ValidationIssueType.DETECTION_INTERVAL,
+                        ValidationAspect.MODEL,
+                        interval
+                    )
+                );
+        }
+    }
+
+    public AggregationBuilder getBucketAggregation(long latestTime) {
+        IntervalTimeConfiguration interval = (IntervalTimeConfiguration) config.getInterval();
+        long intervalInMinutes = IntervalTimeConfiguration.getIntervalInMinute(interval);
+        if (timeRangeToSearchForConfiguredInterval == null) {
+            timeRangeToSearchForConfiguredInterval = histogramAggHelper.getTimeRangeBounds(latestTime, intervalInMinutes * 60000);
+        }
+
+        return histogramAggHelper.getBucketAggregation((int) intervalInMinutes, timeRangeToSearchForConfiguredInterval);
+    }
+
+    private void checkRawDataSparsity(long latestTime) {
+        AggregationBuilder aggregation = getBucketAggregation(latestTime);
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(aggregation).size(0).timeout(requestTimeout);
+        SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder);
+        final ActionListener<SearchResponse> searchResponseListener = ActionListener
+            .wrap(response -> processRawDataResults(response, latestTime), listener::onFailure);
+        // using the original context in listener as user roles have no permissions for internal operations like fetching a
+        // checkpoint
+        clientUtil
+            .asyncRequestWithInjectedSecurity(
+                searchRequest,
+                client::search,
+                user,
+                client,
+                context,
+                searchResponseListener
+            );
+    }
+
+    public double processBucketAggregationResults(Histogram buckets, long latestTime) {
+        long intervalInMillis = config.getIntervalInMilliseconds();
+        return histogramAggHelper.processBucketAggregationResults(buckets, intervalInMillis, config);
+    }
+
+    private void processRawDataResults(SearchResponse response, long latestTime) {
+        Histogram aggregate = null;
+        try {
+            aggregate = histogramAggHelper.checkBucketResultErrors(response);
+        } catch (ValidationException e) {
+            listener.onFailure(e);
+        }
+
+        if
(aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < TimeSeriesSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException(CommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL) + ); + } else { + checkDataFilterSparsity(latestTime); + } + } + + private void checkDataFilterSparsity(long latestTime) { + AggregationBuilder aggregation = getBucketAggregation(latestTime); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + SearchSourceBuilder searchSourceBuilder = histogramAggHelper.getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + + private void processDataFilterResults(SearchResponse response, long latestTime) { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException( + CommonMessages.FILTER_QUERY_TOO_SPARSE, + ValidationIssueType.FILTER_QUERY, + ValidationAspect.MODEL + ) + ); + // blocks below are executed if data is dense enough with filter query applied. 
+ // If HCAD then category fields will be added to bucket aggregation to see if they + // are the root cause of the issues and if not the feature queries will be checked for sparsity + } else if (config.isHighCardinality()) { + getTopEntityForCategoryField(latestTime); + } else { + try { + checkFeatureQueryDelegate(latestTime); + } catch (Exception ex) { + logger.error(ex); + listener.onFailure(ex); + } + } + } + + private void getTopEntityForCategoryField(long latestTime) { + ActionListener> getTopEntityListener = ActionListener + .wrap(topEntity -> checkCategoryFieldSparsity(topEntity, latestTime), exception -> { + listener.onFailure(exception); + logger.error("Failed to get top entity for categorical field", exception); + return; + }); + latestTimeRetriever.getTopEntity(getTopEntityListener); + } + + private void checkCategoryFieldSparsity(Map topEntity, long latestTime) { + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + for (Map.Entry entry : topEntity.entrySet()) { + query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); + } + AggregationBuilder aggregation = getBucketAggregation(latestTime); + SearchSourceBuilder searchSourceBuilder = histogramAggHelper.getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + + private void processTopEntityResults(SearchResponse response, long latestTime) { + Histogram aggregate = null; + try { + aggregate = histogramAggHelper.checkBucketResultErrors(response); + } catch (ValidationException e) { + listener.onFailure(e); + } + + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException(CommonMessages.CATEGORY_FIELD_TOO_SPARSE, ValidationIssueType.CATEGORY, ValidationAspect.MODEL) + ); + } else { + try { + checkFeatureQueryDelegate(latestTime); + } catch (Exception ex) { + logger.error(ex); + listener.onFailure(ex); + } + } + } + + private void checkFeatureQueryDelegate(long latestTime) throws IOException { + ActionListener> validateFeatureQueriesListener = ActionListener.wrap(response -> { + windowDelayRecommendation(latestTime); + }, exception -> { + listener + .onFailure(new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.MODEL)); + }); + MultiResponsesDelegateActionListener> multiFeatureQueriesResponseListener = + new MultiResponsesDelegateActionListener<>( + validateFeatureQueriesListener, + config.getFeatureAttributes().size(), + CommonMessages.FEATURE_QUERY_TOO_SPARSE, + false + ); + + for (Feature feature : config.getFeatureAttributes()) { + AggregationBuilder aggregation = getBucketAggregation(latestTime); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(config.getFilterQuery()); + List featureFields = ParseUtils.getFieldNamesForFeature(feature, xContentRegistry); + for (String featureField : featureFields) { + 
query.filter(QueryBuilders.existsQuery(featureField)); + } + SearchSourceBuilder searchSourceBuilder = histogramAggHelper.getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + try { + Histogram aggregate = histogramAggHelper.checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate, latestTime); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + multiFeatureQueriesResponseListener + .onFailure( + new ValidationException( + CommonMessages.FEATURE_QUERY_TOO_SPARSE, + ValidationIssueType.FEATURE_ATTRIBUTES, + ValidationAspect.MODEL + ) + ); + } else { + multiFeatureQueriesResponseListener + .onResponse(new MergeableList<>(new ArrayList<>(Collections.singletonList(new double[] { fullBucketRate })))); + } + } catch (ValidationException e) { + listener.onFailure(e); + } + + }, e -> { + logger.error(e); + multiFeatureQueriesResponseListener + .onFailure(new OpenSearchStatusException(CommonMessages.FEATURE_QUERY_TOO_SPARSE, RestStatus.BAD_REQUEST, e)); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + context, + searchResponseListener + ); + } + } + + private void sendWindowDelayRec(long latestTimeInMillis) { + long minutesSinceLastStamp = (long) Math.ceil((Instant.now().toEpochMilli() - latestTimeInMillis) / 60000.0); + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.WINDOW_DELAY_REC, minutesSinceLastStamp, minutesSinceLastStamp), + ValidationIssueType.WINDOW_DELAY, + ValidationAspect.MODEL, + new IntervalTimeConfiguration(minutesSinceLastStamp, ChronoUnit.MINUTES) + ) + ); + } + + private void windowDelayRecommendation(long latestTime) { + // Check if there is a better window-delay to recommend and if one was recommended + // then send exception and return, otherwise continue to let user know data is too sparse as explained below + if (Instant.now().toEpochMilli() - latestTime > histogramAggHelper.timeConfigToMilliSec(config.getWindowDelay())) { + sendWindowDelayRec(latestTime); + return; + } + // This case has been reached if following conditions are met: + // 1. no interval recommendation was found that leads to a bucket success rate of >= 0.75 + // 2. bucket success rate with the given interval and just raw data is also below 0.75. + // 3. no single configuration during the following checks reduced the bucket success rate below 0.25 + // This means the rate with all configs applied or just raw data was below 0.75 but the rate when checking each configuration at + // a time was always above 0.25 meaning the best suggestion is to simply ingest more data or change interval since + // we have no more insight regarding the root cause of the lower density. 
+        listener
+            .onFailure(new ValidationException(CommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL));
+    }
+
+}
diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/Processor.java b/src/main/java/org/opensearch/timeseries/rest/handler/Processor.java
new file mode 100644
index 000000000..548f31bba
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/rest/handler/Processor.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.timeseries.rest.handler;
+
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.action.ActionResponse;
+
+/**
+ * Represents a processor capable of initiating a certain process
+ * and then notifying a listener upon completion.
+ *
+ * @param <T> the type of response expected after processing, which must be a subtype of ActionResponse.
+ */
+public interface Processor<T extends ActionResponse> {
+
+    /**
+     * Starts the processing action. Once the processing is completed,
+     * the provided listener is notified with the outcome.
+     *
+     * @param listener the listener to be notified upon the completion of the processing action.
+     */
+    public void start(ActionListener<T> listener);
+}
diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSetting.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSetting.java
new file mode 100644
index 000000000..3e7499175
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSetting.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.timeseries.settings;
+
+import static java.util.Collections.unmodifiableMap;
+import static org.opensearch.common.settings.Setting.Property.Dynamic;
+import static org.opensearch.common.settings.Setting.Property.NodeScope;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.opensearch.common.settings.Setting;
+
+public class TimeSeriesEnabledSetting extends DynamicNumericSetting {
+
+    /**
+     * Singleton instance
+     */
+    private static TimeSeriesEnabledSetting INSTANCE;
+
+    /**
+     * Settings name
+     */
+    public static final String BREAKER_ENABLED = "plugins.timeseries.breaker.enabled";
+
+    public static final Map<String, Setting<?>> settings = unmodifiableMap(new HashMap<String, Setting<?>>() {
+        {
+            /**
+             * time series breaker enable/disable setting
+             */
+            put(BREAKER_ENABLED, Setting.boolSetting(BREAKER_ENABLED, true, NodeScope, Dynamic));
+        }
+    });
+
+    private TimeSeriesEnabledSetting(Map<String, Setting<?>> settings) {
+        super(settings);
+    }
+
+    public static synchronized TimeSeriesEnabledSetting getInstance() {
+        if (INSTANCE == null) {
+            INSTANCE = new TimeSeriesEnabledSetting(settings);
+        }
+        return INSTANCE;
+    }
+
+    /**
+     * Whether the circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause a real-time job to be stopped.
+     * @return whether the circuit breaker is enabled or not.
+ */ + public static boolean isBreakerEnabled() { + return TimeSeriesEnabledSetting.getInstance().getSettingValue(TimeSeriesEnabledSetting.BREAKER_ENABLED); + } +} diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java index 56bbe187a..d2935ae61 100644 --- a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java @@ -25,11 +25,34 @@ public class TimeSeriesSettings { public static final String JOBS_INDEX_MAPPING_FILE = "mappings/job.json"; - // 100,000 insertions costs roughly 1KB. + /** + * Memory Usage Estimation for a Map<String, Integer> with 100,000 entries: + * + * 1. HashMap Object Overhead: This can vary, but let's assume it's about 36 bytes. + * 2. Array Overhead: + * - The array size will be the nearest power of 2 greater than or equal to 100,000 / load factor. + * - Assuming a load factor of 0.75, the array size will be 2^17 = 131,072. + * - The memory usage will be 131,072 * 4 bytes = 524,288 bytes. + * 3. Entry Overhead: Each entry has an overhead of about 32 bytes (object header, hash code, and three references). + * 4. Key Overhead: + * - Each key has an overhead of about 36 bytes (object header, length, hash cache) plus the character data. + * - Assuming the character data is 64 bytes, the total key overhead per entry is 100 bytes. + * 5. Value Overhead: Each Integer object has an overhead of about 16 bytes (object header plus int value). + * + * Total Memory Usage Formula: + * Total Memory Usage = HashMap Object Overhead + Array Overhead + + * (Entry Overhead + Key Overhead + Value Overhead) * Number of Entries + * + * Plugging in the numbers: + * Total Memory Usage = 36 + 524,288 + (32 + 100 + 16) * 100,000 + * ≈ 14,965 kilobytes (≈ 15 MB) + * + * Note: + * This estimation is quite simplistic and the actual memory usage may be different based on the JVM implementation, + * the actual Map implementation being used, and other factors. + */ public static final int DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION = 100_000; - public static final double DOOR_KEEPER_FALSE_POSITIVE_RATE = 0.01; - // clean up door keeper every 60 intervals public static final int DOOR_KEEPER_MAINTENANCE_FREQ = 60; @@ -92,7 +115,7 @@ public class TimeSeriesSettings { public static int CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES = 200_000; /** - * ResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest + * ADResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). * Plus Java object size (12 bytes), we have roughly 1160 bytes per request * @@ -104,18 +127,18 @@ public class TimeSeriesSettings { public static int RESULT_WRITE_QUEUE_SIZE_IN_BYTES = 1160; /** - * FeatureRequest has entityName (# category fields * 256, the recommended limit - * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest - * fields including config Id(roughly 128 bytes), dataStartTimeMillis (long, + * FeatureRequest has entity (max 2 category fields * 256, the recommended limit + * of a keyword field length, 512 bytes), model Id (roughly 256 bytes), runOnce + * boolean (roughly 8 bytes), dataStartTimeMillis (long, * 8 bytes), and currentFeature (16 bytes, assume two features on average). 
- * Plus Java object size (12 bytes), we have roughly 932 bytes per request
+ * Plus Java object size (12 bytes), we have roughly 812 bytes per request
 * assuming we have 2 categorical fields (plan to support 2 categorical fields now).
 * We don't want the total size to exceed 0.1% of the heap.
- * We can have at most 0.1% heap / 932 = heap / 932,000.
+ * We can have at most 0.1% heap / 812 = heap / 812,000.
 * For t3.small, 0.1% heap is about 1MB. The queue's size is up to
- * 10^6 / 932 = 1072
+ * 10^6 / 812 = 1231
 */
- public static int FEATURE_REQUEST_SIZE_IN_BYTES = 932;
+ public static int FEATURE_REQUEST_SIZE_IN_BYTES = 812;
 /**
 * CheckpointMaintainRequest has model Id (roughly 256 bytes), and QueuedRequest
@@ -158,6 +181,11 @@ public class TimeSeriesSettings {
 // for a batch operation, we want all of the bounding box in-place for speed
 public static final double BATCH_BOUNDING_BOX_CACHE_RATIO = 1;
+ // feature processing
+ public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24;
+
+ public static final int MIN_TRAIN_SAMPLES = 512;
+
 // ======================================
 // Cold start setting
 // ======================================
@@ -209,4 +237,22 @@ public class TimeSeriesSettings {
 // such as "there are at least 10000 entities", the default is set to 10,000. That is, requests will count the
 // total entities up to 10,000.
 public static final int MAX_TOTAL_ENTITIES_TO_TRACK = 10_000;
+
+ // ======================================
+ // Validate Detector API setting
+ // ======================================
+ public static final long TOP_VALIDATE_TIMEOUT_IN_MILLIS = 10_000;
+
+ public static final double INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE = 0.75;
+
+ public static final double INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2;
+
+ public static final long MAX_INTERVAL_REC_LENGTH_IN_MINUTES = 60L;
+
+ public static final int MAX_DESCRIPTION_LENGTH = 1000;
+
+ // ======================================
+ // Cache setting
+ // ======================================
+ public static final int DOOR_KEEPER_COUNT_THRESHOLD = 10;
 }
diff --git a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java b/src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java
similarity index 95%
rename from src/main/java/org/opensearch/ad/stats/InternalStatNames.java
rename to src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java
index 56ff012a5..356a7828d 100644
--- a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java
+++ b/src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.stats;
+package org.opensearch.timeseries.stats;
 /**
 * Enum containing names of all internal stats which will not be returned
diff --git a/src/main/java/org/opensearch/timeseries/stats/StatNames.java b/src/main/java/org/opensearch/timeseries/stats/StatNames.java
index a72e3f1b0..8ea32dffe 100644
--- a/src/main/java/org/opensearch/timeseries/stats/StatNames.java
+++ b/src/main/java/org/opensearch/timeseries/stats/StatNames.java
@@ -19,30 +19,46 @@
 * AD stats REST API.
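 * <p>Editor's sketch (illustrative, not part of this change): the new StatType tag
 * lets callers partition stats, e.g. counting the AD-scoped names with
 * <pre>
 * long adStats = Arrays.stream(StatNames.values()).filter(s -> s.getType() == StatType.AD).count();
 * </pre>
 * assuming java.util.Arrays is imported.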
 */
 public enum StatNames {
- AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count"),
- AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count"),
- AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count"),
- AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count"),
- DETECTOR_COUNT("detector_count"),
- SINGLE_ENTITY_DETECTOR_COUNT("single_entity_detector_count"),
- MULTI_ENTITY_DETECTOR_COUNT("multi_entity_detector_count"),
- ANOMALY_DETECTORS_INDEX_STATUS("anomaly_detectors_index_status"),
- ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status"),
- MODELS_CHECKPOINT_INDEX_STATUS("models_checkpoint_index_status"),
- ANOMALY_DETECTION_JOB_INDEX_STATUS("anomaly_detection_job_index_status"),
- ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status"),
- MODEL_INFORMATION("models"),
- AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count"),
- AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count"),
- AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count"),
- AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count"),
- MODEL_COUNT("model_count"),
- MODEL_CORRUTPION_COUNT("model_corruption_count");
+ // common stats
+ CONFIG_INDEX_STATUS("config_index_status", StatType.TIMESERIES),
+ JOB_INDEX_STATUS("job_index_status", StatType.TIMESERIES),
+ // AD stats
+ AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count", StatType.AD),
+ AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count", StatType.AD),
+ AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count", StatType.AD),
+ AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count", StatType.AD),
+ DETECTOR_COUNT("detector_count", StatType.AD),
+ SINGLE_STREAM_DETECTOR_COUNT("single_stream_detector_count", StatType.AD),
+ HC_DETECTOR_COUNT("hc_detector_count", StatType.AD),
+ ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status", StatType.AD),
+ AD_MODELS_CHECKPOINT_INDEX_STATUS("anomaly_models_checkpoint_index_status", StatType.AD),
+ ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status", StatType.AD),
+ MODEL_INFORMATION("models", StatType.AD),
+ AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count", StatType.AD),
+ AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count", StatType.AD),
+ AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count", StatType.AD),
+ AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count", StatType.AD),
+ MODEL_COUNT("model_count", StatType.AD),
+ AD_MODEL_CORRUPTION_COUNT("ad_model_corruption_count", StatType.AD),
+ // forecast stats
+ FORECAST_EXECUTE_REQUEST_COUNT("forecast_execute_request_count", StatType.FORECAST),
+ FORECAST_EXECUTE_FAIL_COUNT("forecast_execute_failure_count", StatType.FORECAST),
+ FORECAST_HC_EXECUTE_REQUEST_COUNT("forecast_hc_execute_request_count", StatType.FORECAST),
+ FORECAST_HC_EXECUTE_FAIL_COUNT("forecast_hc_execute_failure_count", StatType.FORECAST),
+ FORECAST_RESULTS_INDEX_STATUS("forecast_results_index_status", StatType.FORECAST),
+ FORECAST_MODELS_CHECKPOINT_INDEX_STATUS("forecast_models_checkpoint_index_status", StatType.FORECAST),
+ FORECAST_STATE_STATUS("forecast_state_status", StatType.FORECAST),
+ FORECASTER_COUNT("forecaster_count", StatType.FORECAST),
+ SINGLE_STREAM_FORECASTER_COUNT("single_stream_forecaster_count", StatType.FORECAST),
+ HC_FORECASTER_COUNT("hc_forecaster_count", StatType.FORECAST),
+ FORECAST_MODEL_CORRUPTION_COUNT("forecast_model_corruption_count", StatType.FORECAST);
- private String name;
+ private final String name;
+ private final StatType type;
- StatNames(String name) {
+
StatNames(String name, StatType type) {
 this.name = name;
+ this.type = type;
 }
 /**
@@ -54,6 +70,10 @@ public String getName() {
 return name;
 }
+ public StatType getType() {
+ return type;
+ }
+
 /**
 * Get set of stat names
 *
diff --git a/src/main/java/org/opensearch/timeseries/stats/StatType.java b/src/main/java/org/opensearch/timeseries/stats/StatType.java
new file mode 100644
index 000000000..cca482bc7
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/stats/StatType.java
@@ -0,0 +1,18 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.timeseries.stats;
+
+public enum StatType {
+ AD,
+ FORECAST,
+ TIMESERIES
+}
diff --git a/src/main/java/org/opensearch/timeseries/stats/Stats.java b/src/main/java/org/opensearch/timeseries/stats/Stats.java
new file mode 100644
index 000000000..f9b168392
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/stats/Stats.java
@@ -0,0 +1,80 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.timeseries.stats;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class Stats {
+ private Map<String, TimeSeriesStat<?>> stats;
+
+ /**
+ * Constructor
+ *
+ * @param stats Map of the stats that are to be kept
+ */
+ public Stats(Map<String, TimeSeriesStat<?>> stats) {
+ this.stats = stats;
+ }
+
+ /**
+ * Get the stats
+ *
+ * @return all of the stats
+ */
+ public Map<String, TimeSeriesStat<?>> getStats() {
+ return stats;
+ }
+
+ /**
+ * Get individual stat by stat name
+ *
+ * @param key Name of stat
+ * @return TimeSeriesStat
+ * @throws IllegalArgumentException thrown on illegal statName
+ */
+ public TimeSeriesStat<?> getStat(String key) throws IllegalArgumentException {
+ if (!stats.keySet().contains(key)) {
+ throw new IllegalArgumentException("Stat=\"" + key + "\" does not exist");
+ }
+ return stats.get(key);
+ }
+
+ /**
+ * Get a map of the stats that are kept at the node level
+ *
+ * @return Map of stats kept at the node level
+ */
+ public Map<String, TimeSeriesStat<?>> getNodeStats() {
+ return getClusterOrNodeStats(false);
+ }
+
+ /**
+ * Get a map of the stats that are kept at the cluster level
+ *
+ * @return Map of stats kept at the cluster level
+ */
+ public Map<String, TimeSeriesStat<?>> getClusterStats() {
+ return getClusterOrNodeStats(true);
+ }
+
+ private Map<String, TimeSeriesStat<?>> getClusterOrNodeStats(Boolean getClusterStats) {
+ Map<String, TimeSeriesStat<?>> statsMap = new HashMap<>();
+
+ for (Map.Entry<String, TimeSeriesStat<?>> entry : stats.entrySet()) {
+ if (entry.getValue().isClusterLevel() == getClusterStats) {
+ statsMap.put(entry.getKey(), entry.getValue());
+ }
+ }
+ return statsMap;
+ }
+}
diff --git a/src/main/java/org/opensearch/ad/stats/ADStat.java b/src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java
similarity index 86%
rename from src/main/java/org/opensearch/ad/stats/ADStat.java
rename to src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java
index 531205907..e10ab9127 100644
--- a/src/main/java/org/opensearch/ad/stats/ADStat.java
+++ b/src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java
@@ -9,17 +9,17 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.stats;
+package org.opensearch.timeseries.stats;
 import java.util.function.Supplier;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
-import org.opensearch.ad.stats.suppliers.SettableSupplier;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.stats.suppliers.SettableSupplier;
 /**
 * Class represents a stat the plugin keeps track of
 */
-public class ADStat<T> {
+public class TimeSeriesStat<T> {
 private Boolean clusterLevel;
 private Supplier<T> supplier;
@@ -29,7 +29,7 @@ public class ADStat<T> {
 * @param clusterLevel whether the stat has clusterLevel scope or nodeLevel scope
 * @param supplier supplier that returns the stat's value
 */
- public ADStat(Boolean clusterLevel, Supplier<T> supplier) {
+ public TimeSeriesStat(Boolean clusterLevel, Supplier<T> supplier) {
 this.clusterLevel = clusterLevel;
 this.supplier = supplier;
 }
diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java
similarity index 95%
rename from src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java
rename to src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java
index 39acd94ff..0953e9450 100644
--- a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java
+++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.stats.suppliers;
+package org.opensearch.timeseries.stats.suppliers;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.function.Supplier;
diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java
similarity index 92%
rename from src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java
rename to src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java
index ab9177cb5..1da433108 100644
--- a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java
+++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java
@@ -9,11 +9,11 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.stats.suppliers;
+package org.opensearch.timeseries.stats.suppliers;
 import java.util.function.Supplier;
-import org.opensearch.ad.util.IndexUtils;
+import org.opensearch.timeseries.util.IndexUtils;
 /**
 * IndexStatusSupplier provides the status of an index as the value
diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java
similarity index 94%
rename from src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java
rename to src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java
index b39ecdde5..e5e60c6ba 100644
--- a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java
+++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
*/ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; diff --git a/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java b/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java index 5fe0c3850..8765e19c9 100644 --- a/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java +++ b/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java @@ -35,8 +35,8 @@ public class RealtimeTaskCache { // track last job run time, will clean up cache if no access after 2 intervals private long lastJobRunTime; - // detector interval in milliseconds. - private long detectorIntervalInMillis; + // interval in milliseconds. + private long intervalInMillis; // we query result index to check if there are any result generated for detector to tell whether it passed initialization of not. // To avoid repeated query when there is no data, record whether we have done that or not. @@ -47,7 +47,7 @@ public RealtimeTaskCache(String state, Float initProgress, String error, long de this.initProgress = initProgress; this.error = error; this.lastJobRunTime = Instant.now().toEpochMilli(); - this.detectorIntervalInMillis = detectorIntervalInMillis; + this.intervalInMillis = detectorIntervalInMillis; this.queriedResultIndex = false; } @@ -88,6 +88,6 @@ public void setQueriedResultIndex(boolean queriedResultIndex) { } public boolean expired() { - return lastJobRunTime + 2 * detectorIntervalInMillis < Instant.now().toEpochMilli(); + return lastJobRunTime + 2 * intervalInMillis < Instant.now().toEpochMilli(); } } diff --git a/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java b/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java index fe08f94c8..d0a87d9a2 100644 --- a/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java +++ b/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java @@ -15,9 +15,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.transport.TransportService; public class TaskCacheManager { private final Logger logger = LogManager.getLogger(TaskCacheManager.class); @@ -39,7 +43,7 @@ public class TaskCacheManager { protected volatile Integer maxCachedDeletedTask; /** * This field is to cache deleted detector IDs. Hourly cron will poll this queue - * and clean AD results. Check ADTaskManager#cleanResultOfDeletedConfig() + * and clean AD results. Check {@link ADTaskManager#cleanResultOfDeletedConfig} *

Note: any data node serves the delete detector request

 */
 protected Queue<String> deletedConfigs;
@@ -146,16 +150,16 @@ public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Fl
 *
 * If realtime task cache doesn't exist, will do nothing. Next realtime job run will re-init
 * realtime task cache when it finds task cache not inited yet.
- * Check ADTaskManager#initCacheWithCleanupIfRequired(String, AnomalyDetector, TransportService, ActionListener),
- * ADTaskManager#updateLatestRealtimeTaskOnCoordinatingNode(String, String, Long, Long, String, ActionListener)
 *
- * @param detectorId detector id
+ * Check {@link TaskManager#initRealtimeTaskCacheAndCleanupStaleCache(String, Config, TransportService, ActionListener)}
+ *
+ * @param configId config id
 * @param newState new task state
 * @param newInitProgress new init progress
 * @param newError new error
 */
- public void updateRealtimeTaskCache(String detectorId, String newState, Float newInitProgress, String newError) {
- RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId);
+ public void updateRealtimeTaskCache(String configId, String newState, Float newInitProgress, String newError) {
+ RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(configId);
 if (realtimeTaskCache != null) {
 if (newState != null) {
 realtimeTaskCache.setState(newState);
@@ -168,47 +172,47 @@ public void updateRealtimeTaskCache(String detectorId, String newState, Float ne
 }
 if (newState != null && !TaskState.NOT_ENDED_STATES.contains(newState)) {
 // If task is done, will remove its realtime task cache.
- logger.info("Realtime task done with state {}, remove RT task cache for detector ", newState, detectorId);
- removeRealtimeTaskCache(detectorId);
+ logger.info("Realtime task done with state {}, remove RT task cache for config {}", newState, configId);
+ removeRealtimeTaskCache(configId);
 }
 } else {
- logger.debug("Realtime task cache is not inited yet for detector {}", detectorId);
+ logger.debug("Realtime task cache is not inited yet for config {}", configId);
 }
 }
- public void refreshRealtimeJobRunTime(String detectorId) {
- RealtimeTaskCache taskCache = realtimeTaskCaches.get(detectorId);
+ public void refreshRealtimeJobRunTime(String configId) {
+ RealtimeTaskCache taskCache = realtimeTaskCaches.get(configId);
 if (taskCache != null) {
 taskCache.setLastJobRunTime(Instant.now().toEpochMilli());
 }
 }
 /**
- * Get detector IDs from realtime task cache.
- * @return array of detector id
+ * Get config IDs from realtime task cache.
+ * @return array of config id
 */
- public String[] getDetectorIdsInRealtimeTaskCache() {
+ public String[] getConfigIdsInRealtimeTaskCache() {
 return realtimeTaskCaches.keySet().toArray(new String[0]);
 }
 /**
 * Remove detector's realtime task from cache.
- * @param detectorId detector id
+ * @param configId config id
 */
- public void removeRealtimeTaskCache(String detectorId) {
- if (realtimeTaskCaches.containsKey(detectorId)) {
- logger.info("Delete realtime cache for detector {}", detectorId);
- realtimeTaskCaches.remove(detectorId);
+ public void removeRealtimeTaskCache(String configId) {
+ if (realtimeTaskCaches.containsKey(configId)) {
+ logger.info("Delete realtime cache for config {}", configId);
+ realtimeTaskCaches.remove(configId);
 }
 }
 /**
- * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not.
+ * We query result index to check if there are any result generated for config to tell whether it passed initialization or not.
 * To avoid repeated query when there is no data, record whether we have done that or not.
- * @param id detector id
+ * @param configId config id
 */
- public void markResultIndexQueried(String id) {
- RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id);
+ public void markResultIndexQueried(String configId) {
+ RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(configId);
 // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it
 // cannot be found. If the cache is empty, we will return early and wait for it to be
 // initialized.
@@ -218,13 +222,13 @@ public void markResultIndexQueried(String id) {
 }
 /**
- * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not.
+ * We query result index to check if there are any result generated for config to tell whether it passed initialization or not.
 *
- * @param id detector id
+ * @param configId config id
 * @return whether we have queried result index or not.
 */
- public boolean hasQueriedResultIndex(String id) {
- RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id);
+ public boolean hasQueriedResultIndex(String configId) {
+ RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(configId);
 if (realtimeTaskCache != null) {
 return realtimeTaskCache.hasQueriedResultIndex();
 }
diff --git a/src/main/java/org/opensearch/timeseries/task/TaskManager.java b/src/main/java/org/opensearch/timeseries/task/TaskManager.java
new file mode 100644
index 000000000..7424ffb13
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/task/TaskManager.java
@@ -0,0 +1,1085 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.timeseries.task; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.CONFIG_IS_RUNNING; +import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; +import static org.opensearch.timeseries.model.TaskType.taskTypeToString; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.UpdateByQueryAction; +import org.opensearch.index.reindex.UpdateByQueryRequest; +import org.opensearch.script.Script; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import 
org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TaskCancelledException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.function.ResponseTransformer; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public abstract class TaskManager & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + protected static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5; + + private final Logger logger = LogManager.getLogger(TaskManager.class); + + protected final TaskCacheManagerType taskCacheManager; + protected final ClusterService clusterService; + protected final Client client; + protected final String stateIndex; + private final List realTimeTaskTypes; + private final List historicalTaskTypes; + private final List runOnceTaskTypes; + protected final IndexManagementType indexManagement; + protected final NodeStateManager nodeStateManager; + protected final AnalysisType analysisType; + protected final NamedXContentRegistry xContentRegistry; + protected final String configIdFieldName; + + protected volatile Integer maxOldTaskDocsPerConfig; + + protected final ThreadPool threadPool; + private final String allResultIndexPattern; + private final String batchTaskThreadPoolName; + private volatile boolean deleteResultWhenDeleteConfig; + private final TaskState stopped; + + public TaskManager( + TaskCacheManagerType taskCacheManager, + ClusterService clusterService, + Client client, + String stateIndex, + List realTimeTaskTypes, + List historicalTaskTypes, + List runOnceTaskTypes, + IndexManagementType indexManagement, + NodeStateManager nodeStateManager, + AnalysisType analysisType, + NamedXContentRegistry xContentRegistry, + String configIdFieldName, + Setting maxOldADTaskDocsPerConfigSetting, + Settings settings, + ThreadPool threadPool, + String allResultIndexPattern, + String batchTaskThreadPoolName, + Setting deleteResultWhenDeleteConfigSetting, + TaskState stopped + ) { + this.taskCacheManager = taskCacheManager; + this.clusterService = clusterService; + this.client = client; + this.stateIndex = stateIndex; + this.realTimeTaskTypes = realTimeTaskTypes; + this.historicalTaskTypes = historicalTaskTypes; + this.runOnceTaskTypes = runOnceTaskTypes; + this.indexManagement = indexManagement; + this.nodeStateManager = nodeStateManager; + this.analysisType = analysisType; + this.xContentRegistry = xContentRegistry; + this.configIdFieldName = 
configIdFieldName; + + this.maxOldTaskDocsPerConfig = maxOldADTaskDocsPerConfigSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxOldADTaskDocsPerConfigSetting, it -> maxOldTaskDocsPerConfig = it); + + this.threadPool = threadPool; + this.allResultIndexPattern = allResultIndexPattern; + this.batchTaskThreadPoolName = batchTaskThreadPoolName; + + this.deleteResultWhenDeleteConfig = deleteResultWhenDeleteConfigSetting.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(deleteResultWhenDeleteConfigSetting, it -> deleteResultWhenDeleteConfig = it); + + this.stopped = stopped; + } + + public boolean skipUpdateRealtimeTask(String configId, String error) { + RealtimeTaskCache realtimeTaskCache = taskCacheManager.getRealtimeTaskCache(configId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() == 1.0 + && Objects.equals(error, realtimeTaskCache.getError()); + } + + public boolean isHCRealtimeTaskStartInitializing(String detectorId) { + RealtimeTaskCache realtimeTaskCache = taskCacheManager.getRealtimeTaskCache(detectorId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() > 0; + } + + /** + * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime + * task cache directly if expired. + */ + public void maintainRunningRealtimeTasks() { + String[] configIds = taskCacheManager.getConfigIdsInRealtimeTaskCache(); + if (configIds == null || configIds.length == 0) { + return; + } + for (int i = 0; i < configIds.length; i++) { + String configId = configIds[i]; + RealtimeTaskCache taskCache = taskCacheManager.getRealtimeTaskCache(configId); + if (taskCache != null && taskCache.expired()) { + taskCacheManager.removeRealtimeTaskCache(configId); + } + } + } + + public void refreshRealtimeJobRunTime(String detectorId) { + taskCacheManager.refreshRealtimeJobRunTime(detectorId); + } + + public void removeRealtimeTaskCache(String detectorId) { + taskCacheManager.removeRealtimeTaskCache(detectorId); + } + + /** + * Update realtime task cache on realtime config's coordinating node. 
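+ * <p>Worked example (editor's note, assuming TimeSeriesSettings.NUM_MIN_SAMPLES is 32):
+ * rcfTotalUpdates = 8 yields initProgress = 8 / 32f = 0.25, state INIT, and an
+ * estimated (32 - 8) * intervalInMinutes minutes left; once rcfTotalUpdates reaches
+ * 32, initProgress becomes 1.0 and the state flips to RUNNING.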
+ * + * @param configId config id + * @param state new state + * @param rcfTotalUpdates rcf total updates + * @param intervalInMinutes config interval in minutes + * @param error error + * @param listener action listener + */ + public void updateLatestRealtimeTaskOnCoordinatingNode( + String configId, + String state, + Long rcfTotalUpdates, + Long intervalInMinutes, + String error, + ActionListener listener + ) { + Float initProgress = null; + String newState = null; + // calculate init progress and task state with RCF total updates + if (intervalInMinutes != null && rcfTotalUpdates != null) { + newState = TaskState.INIT.name(); + if (rcfTotalUpdates < TimeSeriesSettings.NUM_MIN_SAMPLES) { + initProgress = (float) rcfTotalUpdates / TimeSeriesSettings.NUM_MIN_SAMPLES; + } else { + newState = TaskState.RUNNING.name(); + initProgress = 1.0f; + } + } + // Check if new state is not null and override state calculated with rcf total updates + if (state != null) { + newState = state; + } + + error = Optional.ofNullable(error).orElse(""); + if (!taskCacheManager.isRealtimeTaskChangeNeeded(configId, newState, initProgress, error)) { + // If task not changed, no need to update, just return + listener.onResponse(null); + return; + } + Map updatedFields = new HashMap<>(); + updatedFields.put(TimeSeriesTask.COORDINATING_NODE_FIELD, clusterService.localNode().getId()); + if (initProgress != null) { + updatedFields.put(TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress); + updatedFields + .put( + TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD, + Math.max(0, TimeSeriesSettings.NUM_MIN_SAMPLES - rcfTotalUpdates) * intervalInMinutes + ); + } + if (newState != null) { + updatedFields.put(TimeSeriesTask.STATE_FIELD, newState); + } + if (error != null) { + updatedFields.put(TimeSeriesTask.ERROR_FIELD, error); + } + Float finalInitProgress = initProgress; + // Variable used in lambda expression should be final or effectively final + String finalError = error; + String finalNewState = newState; + updateLatestTask(configId, realTimeTaskTypes, updatedFields, ActionListener.wrap(r -> { + logger.debug("Updated latest realtime AD task successfully for config {}", configId); + taskCacheManager.updateRealtimeTaskCache(configId, finalNewState, finalInitProgress, finalError); + listener.onResponse(r); + }, e -> { + logger.error("Failed to update realtime task for config " + configId, e); + listener.onFailure(e); + })); + } + + /** + * Update latest task of a config. 
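+ * <p>Illustrative call (editor's sketch; taskTypes and listener are assumed to be in scope):
+ * <pre>
+ * taskManager.updateLatestTask(configId, taskTypes,
+ *     Map.of(TimeSeriesTask.STATE_FIELD, TaskState.STOPPED.name()), listener);
+ * </pre>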
+ * + * @param configId config id + * @param taskTypes task types + * @param updatedFields updated fields, key: filed name, value: new value + * @param listener action listener + */ + public void updateLatestTask( + String configId, + List taskTypes, + Map updatedFields, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(configId, taskTypes, (task) -> { + if (task.isPresent()) { + updateTask(task.get().getTaskId(), updatedFields, listener); + } else { + listener.onFailure(new ResourceNotFoundException(configId, CommonMessages.CAN_NOT_FIND_LATEST_TASK)); + } + }, null, false, listener); + } + + public void getAndExecuteOnLatestConfigLevelTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + TransportService transportService, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(config.getId(), getTaskTypes(dateRange), (task) -> { + if (!task.isPresent() || task.get().isDone()) { + updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, runOnce, user, TaskState.CREATED, listener); + } else { + listener.onFailure(new OpenSearchStatusException(CONFIG_IS_RUNNING, RestStatus.BAD_REQUEST)); + } + }, transportService, true, listener); + } + + public void updateLatestFlagOfOldTasksAndCreateNewTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + TaskState initialState, + ActionListener listener + ) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(stateIndex); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, config.getId())); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + // make sure we reset all latest task as false when user switch from single entity to HC, vice versa. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(getTaskTypes(dateRange, true, runOnce)))); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", TimeSeriesTask.IS_LATEST_FIELD, false); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (bulkFailures.isEmpty()) { + // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job + // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct + // coordinating node once realtime job starts. + // For historical analysis, this method will be called on coordinating node, so we can set coordinating + // node as local node. + String coordinatingNode = dateRange == null ? null : clusterService.localNode().getId(); + createNewTask(config, dateRange, runOnce, user, coordinatingNode, initialState, listener); + } else { + logger.error("Failed to update old task's state for detector: {}, response: {} ", config.getId(), r.toString()); + listener.onFailure(bulkFailures.get(0).getCause()); + } + }, e -> { + logger.error("Failed to reset old tasks as not latest for detector " + config.getId(), e); + listener.onFailure(e); + })); + } + + /** + * Get latest task and execute consumer function. + * [Important!] 
Make sure listener returns in function + * + * @param configId config id + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param action listener response type + */ + public void getAndExecuteOnLatestConfigLevelTask( + String configId, + List taskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + ActionListener listener + ) { + getAndExecuteOnLatestConfigTask(configId, null, null, taskTypes, function, transportService, resetTaskState, listener); + } + + /** + * Get one latest task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param parentTaskId parent task id + * @param entity entity value + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param action listener response type + */ + public void getAndExecuteOnLatestConfigTask( + String configId, + String parentTaskId, + Entity entity, + List taskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + ActionListener listener + ) { + getAndExecuteOnLatestTasks(configId, parentTaskId, entity, taskTypes, (taskList) -> { + if (taskList != null && taskList.size() > 0) { + function.accept(Optional.ofNullable(taskList.get(0))); + } else { + function.accept(Optional.empty()); + } + }, transportService, resetTaskState, 1, listener); + } + + public List getTaskTypes(DateRange dateRange) { + return getTaskTypes(dateRange, false, false); + } + + /** + * Update latest realtime task. 
+ * + * @param configId config id + * @param state task state + * @param error error + * @param transportService transport service + * @param listener action listener + */ + public void stopLatestRealtimeTask( + String configId, + TaskState state, + Exception error, + TransportService transportService, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(configId, realTimeTaskTypes, (adTask) -> { + if (adTask.isPresent() && !adTask.get().isDone()) { + Map updatedFields = new HashMap<>(); + updatedFields.put(TimeSeriesTask.STATE_FIELD, state.name()); + if (error != null) { + updatedFields.put(TimeSeriesTask.ERROR_FIELD, error.getMessage()); + } + ExecutorFunction function = () -> updateTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { + if (error == null) { + listener.onResponse(new JobResponse(configId)); + } else { + listener.onFailure(error); + } + }, e -> { listener.onFailure(e); })); + + String coordinatingNode = adTask.get().getCoordinatingNode(); + if (coordinatingNode != null && transportService != null) { + cleanConfigCache(adTask.get(), transportService, function, listener); + } else { + function.execute(); + } + } else { + listener.onFailure(new OpenSearchStatusException("job is already stopped: " + configId, RestStatus.OK)); + } + }, null, false, listener); + } + + protected void resetTaskStateAsStopped( + TimeSeriesTask task, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + cleanConfigCache(task, transportService, () -> { + String taskId = task.getTaskId(); + Map updatedFields = ImmutableMap.of(TimeSeriesTask.STATE_FIELD, stopped.name()); + updateTask(taskId, updatedFields, ActionListener.wrap(r -> { + task.setState(stopped.name()); + if (function != null) { + function.execute(); + } + // For realtime anomaly detection, we only create config level task, no entity level realtime task. + if (isHistoricalHCTask(task)) { + // Reset running entity tasks as STOPPED + resetEntityTasksAsStopped(taskId); + } + }, e -> { + logger.error("Failed to update task state as stopped for task " + taskId, e); + listener.onFailure(e); + })); + }, listener); + } + + /** + * Get latest config tasks and execute consumer function. + * [Important!] 
Make sure listener returns in function + * + * @param configId config id + * @param parentTaskId parent task id + * @param entity entity value + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param size return how many tasks + * @param listener action listener + * @param response type of action listener + */ + public void getAndExecuteOnLatestTasks( + String configId, + String parentTaskId, + Entity entity, + List taskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + int size, + ActionListener listener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, configId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + if (parentTaskId != null) { + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId)); + } + if (taskTypes != null && taskTypes.size() > 0) { + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, TaskType.taskTypeToString(taskTypes))); + } + if (entity != null && !ParseUtils.isNullOrEmpty(entity.getAttributes())) { + String path = "entity"; + String entityKeyFieldName = path + ".name"; + String entityValueFieldName = path + ".value"; + + for (Map.Entry attribute : entity.getAttributes().entrySet()) { + BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); + TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); + + entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); + query.filter(nestedQueryBuilder); + } + } + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(stateIndex); + + client.search(searchRequest, ActionListener.wrap(r -> { + // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 + // getTotalHits will be null when we track_total_hits is false in the query request. + // Add more checking here to cover some unknown cases. + List tsTasks = new ArrayList<>(); + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + // don't throw exception here as consumer functions need to handle missing task + // in different way. 
+ function.accept(tsTasks); + return; + } + BiCheckedFunction parserMethod = getTaskParser(); + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + TaskClass tsTask = parserMethod.apply(parser, searchHit.getId()); + tsTasks.add(tsTask); + } catch (Exception e) { + String message = "Failed to parse task for config " + configId + ", task id " + searchHit.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + if (resetTaskState) { + resetLatestConfigTaskState(tsTasks, function, transportService, listener); + } else { + function.accept(tsTasks); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.accept(new ArrayList<>()); + } else { + logger.error("Failed to search task for config " + configId, e); + listener.onFailure(e); + } + })); + } + + protected void resetRealtimeConfigTaskState( + List runningRealtimeTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + if (ParseUtils.isNullOrEmpty(runningRealtimeTasks)) { + function.execute(); + return; + } + TimeSeriesTask tsTask = runningRealtimeTasks.get(0); + String configId = tsTask.getConfigId(); + GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(configId); + client.get(getJobRequest, ActionListener.wrap(r -> { + if (r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + if (!job.isEnabled()) { + logger.debug("job is disabled, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } else { + function.execute(); + } + } catch (IOException e) { + logger.error(" Failed to parse job " + configId, e); + listener.onFailure(e); + } + } else { + logger.debug("job is not found, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + logger.debug("job is not found, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } else { + logger.error("Fail to get realtime job for config " + configId, e); + listener.onFailure(e); + } + })); + } + + /** + * Handle exceptions for task. Update task state and record error message. + * + * @param task AD task + * @param e exception + */ + public void handleTaskException(TaskClass task, Exception e) { + // TODO: handle timeout exception + String state = TaskState.FAILED.name(); + Map updatedFields = new HashMap<>(); + if (e instanceof DuplicateTaskException) { + // If user send multiple start detector request, we will meet race condition. + // Cache manager will put first request in cache and throw DuplicateTaskException + // for the second request. We will delete the second task. + logger + .warn( + "There is already one running task for config, configId:" + + task.getConfigId() + + ". 
Will delete task " + + task.getTaskId() + ); + deleteTask(task.getTaskId()); + return; + } + if (e instanceof TaskCancelledException) { + logger.info("task cancelled, taskId: {}, configId: {}", task.getTaskId(), task.getConfigId()); + state = stopped.name(); + String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); + if (stoppedBy != null) { + updatedFields.put(TimeSeriesTask.STOPPED_BY_FIELD, stoppedBy); + } + } else { + logger.error("Failed to execute batch task, task id: " + task.getTaskId() + ", config id: " + task.getConfigId(), e); + } + updatedFields.put(TimeSeriesTask.ERROR_FIELD, ExceptionUtil.getErrorMessage(e)); + updatedFields.put(TimeSeriesTask.STATE_FIELD, state); + updatedFields.put(TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); + updateTask(task.getTaskId(), updatedFields); + } + + /** + * Update task with specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: filed name, value: new value + */ + public void updateTask(String taskId, Map updatedFields) { + updateTask(taskId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + logger.debug("Updated task successfully: {}, task id: {}", response.status(), taskId); + } else { + logger.error("Failed to update task {}, status: {}", taskId, response.status()); + } + }, e -> { logger.error("Failed to update task: " + taskId, e); })); + } + + /** + * Update task for specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: filed name, value: new value + * @param listener action listener + */ + public void updateTask(String taskId, Map updatedFields, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(stateIndex, taskId); + Map updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updatedContent.put(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, listener); + } + + /** + * Delete task with task id. + * + * @param taskId task id + */ + public void deleteTask(String taskId) { + deleteTask(taskId, ActionListener.wrap(r -> { logger.info("Deleted task {} with status: {}", taskId, r.status()); }, e -> { + logger.error("Failed to delete task " + taskId, e); + })); + } + + /** + * Delete task with task id. + * + * @param taskId task id + * @param listener action listener + */ + public void deleteTask(String taskId, ActionListener listener) { + DeleteRequest deleteRequest = new DeleteRequest(stateIndex, taskId); + client.delete(deleteRequest, listener); + } + + /** + * Create config task directly without checking index exists of not. + * [Important!] 
Make sure listener returns in function + * + * @param tsTask Time series task + * @param function consumer function + * @param listener action listener + * @param action listener response type + */ + public void createTaskDirectly(TaskClass tsTask, Consumer function, ActionListener listener) { + IndexRequest request = new IndexRequest(stateIndex); + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + request + .source(tsTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { + logger.error("Failed to create task for config " + tsTask.getConfigId(), e); + listener.onFailure(e); + })); + } catch (Exception e) { + logger.error("Failed to create task for config " + tsTask.getConfigId(), e); + listener.onFailure(e); + } + } + + protected void cleanOldConfigTaskDocs( + IndexResponse response, + TaskClass tsTask, + ResponseTransformer responseTransformer, + ActionListener delegatedListener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, tsTask.getConfigId())); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, false)); + + if (tsTask.isHistoricalTask()) { + // If historical task, only delete detector level task. It may take longer time to delete entity tasks. + // We will delete child task (entity task) of config level task in hourly cron job. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(historicalTaskTypes))); + } else if (tsTask.isRunOnceTask()) { + // We don't have entity level task for run once detection, so will delete all tasks. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(runOnceTaskTypes))); + } else { + // We don't have entity level task for realtime detection, so will delete all tasks. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(realTimeTaskTypes))); + } + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder + .query(query) + .sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC) + // Search query "from" starts from maxOldTaskDocsPerConfig. 
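+ // Editor's note: with maxOldTaskDocsPerConfig at its assumed default of 1, the newest
+ // non-latest task doc is kept and up to MAX_OLD_AD_TASK_DOCS older docs are fetched for deletion.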
+ .from(maxOldTaskDocsPerConfig) + .size(MAX_OLD_AD_TASK_DOCS); + searchRequest.source(sourceBuilder).indices(stateIndex); + String configId = tsTask.getConfigId(); + deleteTaskDocs(configId, searchRequest, () -> { + if (tsTask.isHistoricalTask()) { + // run batch result action for historical analysis + runBatchResultAction(response, tsTask, responseTransformer, delegatedListener); + } else { + // use the responseTransformer to transform the response + T transformedResponse = responseTransformer.transform(response); + delegatedListener.onResponse(transformedResponse); + } + }, delegatedListener); + } + + public void deleteTaskDocs(String configId, SearchRequest searchRequest, ExecutorFunction function, ActionListener listener) { + ActionListener searchListener = ActionListener.wrap(r -> { + Iterator iterator = r.getHits().iterator(); + if (iterator.hasNext()) { + BulkRequest bulkRequest = new BulkRequest(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + TimeSeriesTask task = null; + if (analysisType.isAD()) { + task = ADTask.parse(parser, searchHit.getId()); + } else { + task = ForecastTask.parse(parser, searchHit.getId()); + } + + logger.debug("Delete old task: {} of config: {}", task.getTaskId(), task.getConfigId()); + bulkRequest.add(new DeleteRequest(stateIndex).id(task.getTaskId())); + } catch (Exception e) { + listener.onFailure(e); + } + } + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + logger.info("Old tasks deleted for config {}", configId); + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Add config task into cache. Task id: {}", bulkItemResponse.getId()); + // add deleted task in cache and delete its child tasks and results + taskCacheManager.addDeletedTask(bulkItemResponse.getId()); + } + } + } + // delete child tasks and results of this task + cleanChildTasksAndResultsOfDeletedTask(); + function.execute(); + }, e -> { + logger.warn("Failed to clean tasks for config " + configId, e); + listener.onFailure(e); + })); + } else { + function.execute(); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.execute(); + } else { + listener.onFailure(e); + } + }); + + client.search(searchRequest, searchListener); + } + + /** + * Poll deleted config task from cache and delete its child tasks and results. 
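+ * <p>Editor's note: each scheduled run drains one task id, deleting its results, then its
+ * child tasks, then re-invoking itself, so the queue empties one entry per
+ * DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS hop rather than in a single burst.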
+ */ + public void cleanChildTasksAndResultsOfDeletedTask() { + if (!taskCacheManager.hasDeletedTask()) { + return; + } + threadPool.schedule(() -> { + String taskId = taskCacheManager.pollDeletedTask(); + if (taskId == null) { + return; + } + DeleteByQueryRequest deleteResultsRequest = new DeleteByQueryRequest(allResultIndexPattern); + deleteResultsRequest.setQuery(new TermsQueryBuilder(CommonName.TASK_ID_FIELD, taskId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteResultsRequest, ActionListener.wrap(res -> { + logger.debug("Successfully deleted results of task " + taskId); + DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(stateIndex); + deleteChildTasksRequest.setQuery(new TermsQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, taskId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { + logger.debug("Successfully deleted child tasks of task " + taskId); + cleanChildTasksAndResultsOfDeletedTask(); + }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); + }, ex -> { logger.error("Failed to delete results for task " + taskId, ex); })); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), batchTaskThreadPoolName); + } + + protected void resetEntityTasksAsStopped(String configTaskId) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(stateIndex); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, configTaskId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, ADTaskType.HISTORICAL_HC_ENTITY.name())); + query.filter(new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, NOT_ENDED_STATES)); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", TimeSeriesTask.STATE_FIELD, TaskState.INACTIVE.name()); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (ParseUtils.isNullOrEmpty(bulkFailures)) { + logger.debug("Updated {} child entity tasks state for config task {}", r.getUpdated(), configTaskId); + } else { + logger.error("Failed to update child entity task's state for config task {} ", configTaskId); + } + }, e -> logger.error("Exception happened when update child entity task's state for config task " + configTaskId, e))); + } + + /** + * Set old task's latest flag as false. 
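+ * <p>Editor's note: each task is re-indexed whole under its existing task id, so the
+ * latest flag and the last-update time change together in one document write.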
+ * @param tasks list of tasks + */ + public void resetLatestFlagAsFalse(List tasks) { + if (tasks == null || tasks.size() == 0) { + return; + } + BulkRequest bulkRequest = new BulkRequest(); + tasks.forEach(task -> { + try { + task.setLatest(false); + task.setLastUpdateTime(Instant.now()); + IndexRequest indexRequest = new IndexRequest(stateIndex) + .id(task.getTaskId()) + .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (Exception e) { + logger.error("Failed to serialize task to XContent, task id " + task.getTaskId(), e); + } + }); + + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Successfully reset task's latest flag to false. Task id: {}", bulkItemResponse.getId()); + } else { + logger.warn("Failed to reset task's latest flag to false. Task id: " + bulkItemResponse.getId()); + } + } + } + }, e -> { logger.warn("Failed to reset tasks' latest flag to false", e); })); + } + + /** + * Delete a config's task docs. + * [Important!] Make sure the listener is eventually notified inside the given function. + * + * @param configId config id + * @param function time series function + * @param listener action listener + */ + public void deleteTasks(String configId, ExecutorFunction function, ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest(stateIndex); + + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, configId)); + + request.setQuery(query); + client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(r -> { + if (r.getBulkFailures() == null || r.getBulkFailures().size() == 0) { + logger.info("Tasks deleted for config {}", configId); + deleteResultOfConfig(configId); + function.execute(); + } else { + listener.onFailure(new OpenSearchStatusException("Failed to delete all tasks", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + logger.info("Failed to delete tasks for " + configId, e); + if (e instanceof IndexNotFoundException) { + deleteResultOfConfig(configId); + function.execute(); + } else { + listener.onFailure(e); + } + })); + } + + public void deleteResultOfConfig(String configId) { + if (!deleteResultWhenDeleteConfig) { + logger.info("Won't delete result for {} as the delete result setting is disabled", configId); + return; + } + logger.info("Start to delete results of config {}", configId); + DeleteByQueryRequest deleteResultsRequest = new DeleteByQueryRequest(allResultIndexPattern); + deleteResultsRequest.setQuery(new TermQueryBuilder(configIdFieldName, configId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteResultsRequest, ActionListener.wrap(response -> { + logger.debug("Successfully deleted results of config " + configId); + }, exception -> { + logger.error("Failed to delete results of config " + configId, exception); + taskCacheManager.addDeletedConfig(configId); + })); + } + + /** + * Clean results of deleted config. 
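+ * Config ids whose result deletion failed in {@link #deleteResultOfConfig(String)} are queued in the + * task cache; each invocation polls one queued config id and retries the deletion.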
+ */ + public void cleanResultOfDeletedConfig() { + String configId = taskCacheManager.pollDeletedConfig(); + if (configId != null) { + deleteResultOfConfig(configId); + } + } + + public abstract void startHistorical( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener + ); + + protected abstract TaskType getTaskType(Config config, DateRange dateRange, boolean runOnce); + + protected abstract void createNewTask( + Config config, + DateRange dateRange, + boolean runOnce, + User user, + String coordinatingNode, + TaskState initialState, + ActionListener listener + ); + + public abstract void cleanConfigCache( + TimeSeriesTask task, + TransportService transportService, + ExecutorFunction function, + ActionListener listener + ); + + protected abstract boolean isHistoricalHCTask(TimeSeriesTask task); + + protected abstract void resetLatestConfigTaskState( + List tasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ); + + protected abstract void onIndexConfigTaskResponse( + IndexResponse response, + TaskClass tsTask, + BiConsumer> function, + ActionListener listener + ); + + protected abstract void runBatchResultAction( + IndexResponse response, + TaskClass tsTask, + ResponseTransformer responseTransformer, + ActionListener listener + ); + + protected abstract BiCheckedFunction getTaskParser(); + + /** + * The function initializes the real-time task cache and only performs cleanup if it is deemed necessary. + * @param configId config id + * @param config config accessor + * @param transportService Transport service + * @param listener listener to report whether the cache initialization succeeded + */ + public abstract void initRealtimeTaskCacheAndCleanupStaleCache( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ); + + public abstract void createRunOnceTaskAndCleanupStaleTasks( + String configId, + Config config, + TransportService transportService, + ActionListener listener + ); + + public abstract List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, boolean runOnce); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseDeleteConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteConfigTransportAction.java new file mode 100644 index 000000000..d7e1c355f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteConfigTransportAction.java @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_DELETE_CONFIG; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; 
+import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseDeleteConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager, ConfigType extends Config> + extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(BaseDeleteConfigTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final TransportService transportService; + private NamedXContentRegistry xContentRegistry; + private final TaskManagerType taskManager; + private volatile Boolean filterByEnabled; + private final NodeStateManager nodeStateManager; + private final AnalysisType analysisType; + private final String stateIndex; + private final Class configTypeClass; + private final List batchTaskTypes; + + public BaseDeleteConfigTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, + TaskManagerType taskManager, + String deleteConfigAction, + Setting filterByBackendRoleSetting, + AnalysisType analysisType, + String stateIndex, + Class configTypeClass, + List historicalTaskTypes + ) { + super(deleteConfigAction, transportService, actionFilters, DeleteConfigRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.taskManager = taskManager; + this.nodeStateManager = nodeStateManager; + filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + + this.analysisType = analysisType; + this.stateIndex = stateIndex; + this.configTypeClass = configTypeClass; + this.batchTaskTypes = historicalTaskTypes; + } + + @Override + protected void doExecute(Task task, DeleteConfigRequest request, ActionListener actionListener) { + String configId = request.getConfigID(); + LOG.info("Delete job 
{}", configId); + User user = ParseUtils.getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_CONFIG); + // By the time request reaches here, the user permissions are validated by Security plugin. + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configId, + filterByEnabled, + listener, + (input) -> nodeStateManager.getConfig(configId, analysisType, config -> { + if (!config.isPresent()) { + // In a mixed cluster, if a delete detector request routes to a node running AD 1.0, it will + // not delete the detector tasks. Users can re-delete these detectors after the cluster is upgraded; + // in that case, the detector is not present. + LOG.info("Can't find config {}", configId); + taskManager.deleteTasks(configId, () -> deleteJobDoc(configId, listener), listener); + return; + } + // Check if a realtime job or batch analysis task is running. If neither is running, we + // can delete the config. + getJob(configId, listener, () -> { + taskManager.getAndExecuteOnLatestConfigLevelTask(configId, batchTaskTypes, configTask -> { + if (configTask.isPresent() && !configTask.get().isDone()) { + listener + .onFailure(new OpenSearchStatusException("Run-once or historical analysis is running", RestStatus.BAD_REQUEST)); + } else { + taskManager.deleteTasks(configId, () -> deleteJobDoc(configId, listener), listener); + } + // false means don't reset the task state to inactive/stopped. We only check whether the task + // has finished, so there is no need to reset its state. + }, transportService, false, listener); + }); + }, listener), + client, + clusterService, + xContentRegistry, + configTypeClass + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void deleteJobDoc(String configId, ActionListener listener) { + LOG.info("Delete job {}", configId); + DeleteRequest deleteRequest = new DeleteRequest(CommonName.JOB_INDEX, configId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, ActionListener.wrap(response -> { + if (response.getResult() == DocWriteResponse.Result.DELETED || response.getResult() == DocWriteResponse.Result.NOT_FOUND) { + deleteStateDoc(configId, listener); + } else { + String message = "Failed to delete job " + configId; + LOG.error(message); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, exception -> { + LOG.error("Failed to delete job for " + configId, exception); + if (exception instanceof IndexNotFoundException) { + deleteStateDoc(configId, listener); + } else { + listener.onFailure(exception); + } + })); + } + + private void deleteStateDoc(String configId, ActionListener listener) { + LOG.info("Delete config state {}", configId); + DeleteRequest deleteRequest = new DeleteRequest(stateIndex, configId); + client.delete(deleteRequest, ActionListener.wrap(response -> { + // whether the state doc was deleted or not, continue, as it may not exist + deleteConfigDoc(configId, listener); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + deleteConfigDoc(configId, listener); + } else { + LOG.error("Failed to delete state", exception); + listener.onFailure(exception); + } + })); + } + + private void deleteConfigDoc(String configId, ActionListener listener) { + LOG.info("Delete config {}", configId); + DeleteRequest deleteRequest = new 
DeleteRequest(CommonName.CONFIG_INDEX, configId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, listener); + } + + private void getJob(String configId, ActionListener listener, ExecutorFunction function) { + if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(configId); + client.get(request, ActionListener.wrap(response -> onGetJobResponseForWrite(response, listener, function), exception -> { + LOG.error("Failed to get job: " + configId, exception); + listener.onFailure(exception); + })); + } else { + function.execute(); + } + } + + private void onGetJobResponseForWrite(GetResponse response, ActionListener listener, ExecutorFunction function) + throws IOException { + if (response.isExists()) { + String jobId = response.getId(); + if (jobId != null) { + // if the job is still enabled for this config, we can't delete the config + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + if (job.isEnabled()) { + listener.onFailure(new OpenSearchStatusException("Job is running: " + jobId, RestStatus.BAD_REQUEST)); + } else { + function.execute(); + } + } catch (IOException e) { + String message = "Failed to parse job " + jobId; + LOG.error(message, e); + function.execute(); + } + } + } else { + function.execute(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java similarity index 52% rename from src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java index 10aa64725..8a638e401 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java @@ -1,15 +1,9 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -19,44 +13,44 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.task.TaskCacheManager; import org.opensearch.transport.TransportService; -public class DeleteModelTransportAction extends - TransportNodesAction { - private static final Logger LOG = LogManager.getLogger(DeleteModelTransportAction.class); +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class BaseDeleteModelTransportAction, CacheProviderType extends CacheProvider, TaskCacheManagerType extends TaskCacheManager, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart> + extends TransportNodesAction { + + private static final Logger LOG = LogManager.getLogger(BaseDeleteModelTransportAction.class); private NodeStateManager nodeStateManager; - private ModelManager modelManager; - private FeatureManager featureManager; - private CacheProvider cache; - private ADTaskCacheManager adTaskCacheManager; - private EntityColdStarter coldStarter; - - @Inject - public DeleteModelTransportAction( + private CacheProviderType cache; + private TaskCacheManagerType adTaskCacheManager; + private ModelColdStartType coldStarter; + + public BaseDeleteModelTransportAction( ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, NodeStateManager nodeStateManager, - ModelManager modelManager, - FeatureManager featureManager, - CacheProvider cache, - ADTaskCacheManager adTaskCacheManager, - EntityColdStarter coldStarter + CacheProviderType cache, + TaskCacheManagerType taskCacheManager, + ModelColdStartType coldStarter, + String deleteModelAction ) { super( - DeleteModelAction.NAME, + deleteModelAction, threadPool, clusterService, transportService, @@ -67,10 +61,8 @@ public DeleteModelTransportAction( DeleteModelNodeResponse.class ); this.nodeStateManager = nodeStateManager; - this.modelManager = modelManager; - this.featureManager = featureManager; this.cache = cache; - this.adTaskCacheManager = adTaskCacheManager; + this.adTaskCacheManager = taskCacheManager; this.coldStarter = coldStarter; } @@ -104,34 +96,18 @@ protected DeleteModelNodeResponse 
newNodeResponse(StreamInput in) throws IOExcep @Override protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { - String adID = request.getAdID(); - LOG.info("Delete model for {}", adID); - // delete in-memory models and model checkpoint - modelManager - .clear( - adID, - ActionListener - .wrap( - r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), - e -> LOG.error("Fail to delete model for " + adID, e) - ) - ); + String configID = request.getConfigID(); + LOG.info("Delete model for {}", configID); + nodeStateManager.clear(configID); - // delete buffered shingle data - featureManager.clear(adID); + cache.get().clear(configID); - // delete transport state - nodeStateManager.clear(adID); - - cache.get().clear(adID); - - coldStarter.clear(adID); + coldStarter.clear(configID); // delete realtime task cache - adTaskCacheManager.removeRealtimeTaskCache(adID); + adTaskCacheManager.removeRealtimeTaskCache(configID); - LOG.info("Finished deleting {}", adID); + LOG.info("Finished deleting {}", configID); return new DeleteModelNodeResponse(clusterService.localNode()); } - } diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseEntityProfileTransportAction.java similarity index 70% rename from src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseEntityProfileTransportAction.java index fedfb2aa7..68bea6e1b 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseEntityProfileTransportAction.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Locale; @@ -20,34 +20,37 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; import 
org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + /** * Transport action to get entity profile. */ -public class EntityProfileTransportAction extends HandledTransportAction { +public class BaseEntityProfileTransportAction, CacheProviderType extends CacheProvider> + extends HandledTransportAction { - private static final Logger LOG = LogManager.getLogger(EntityProfileTransportAction.class); + private static final Logger LOG = LogManager.getLogger(BaseEntityProfileTransportAction.class); public static final String NO_NODE_FOUND_MSG = "Cannot find model hosting node"; public static final String NO_MODEL_ID_FOUND_MSG = "Cannot find model id"; static final String FAIL_TO_GET_ENTITY_PROFILE_MSG = "Cannot get entity profile info"; @@ -56,33 +59,36 @@ public class EntityProfileTransportAction extends HandledTransportAction requestTimeOut ) { - super(EntityProfileAction.NAME, transportService, actionFilters, EntityProfileRequest::new); + super(entityProfileAction, transportService, actionFilters, EntityProfileRequest::new); this.transportService = transportService; this.hashRing = hashRing; this.option = TransportRequestOptions .builder() .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings)) + .withTimeout(requestTimeOut.get(settings)) .build(); this.clusterService = clusterService; this.cacheProvider = cacheProvider; + this.entityProfileAction = entityProfileAction; } @Override protected void doExecute(Task task, EntityProfileRequest request, ActionListener listener) { - String adID = request.getAdID(); + String adID = request.getConfigID(); Entity entityValue = request.getEntityValue(); Optional modelIdOptional = entityValue.getModelId(adID); if (false == modelIdOptional.isPresent()) { @@ -91,7 +97,7 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener } // we use entity's toString (e.g., app_0) to find its node // This should be consistent with how we land a model node in AnomalyResultTransportAction - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(entityValue.toString()); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime(entityValue.toString()); if (false == node.isPresent()) { listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); return; @@ -100,12 +106,12 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener String modelId = modelIdOptional.get(); DiscoveryNode localNode = clusterService.localNode(); if (localNode.getId().equals(nodeId)) { - EntityCache cache = cacheProvider.get(); + CacheType cache = cacheProvider.get(); Set profilesToCollect = request.getProfilesToCollect(); EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); if (profilesToCollect.contains(EntityProfileName.ENTITY_INFO)) { builder.setActive(cache.isActive(adID, modelId)); - builder.setLastActiveMs(cache.getLastActiveMs(adID, modelId)); + builder.setLastActiveMs(cache.getLastActiveTime(adID, modelId)); } if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS) || profilesToCollect.contains(EntityProfileName.STATE)) { builder.setTotalUpdates(cache.getTotalUpdates(adID, modelId)); @@ -126,35 +132,29 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener try { transportService - .sendRequest( - node.get(), - EntityProfileAction.NAME, - request, - option, - new TransportResponseHandler() { - - @Override - public 
EntityProfileResponse read(StreamInput in) throws IOException { - return new EntityProfileResponse(in); - } - - @Override - public void handleResponse(EntityProfileResponse response) { - listener.onResponse(response); - } - - @Override - public void handleException(TransportException exp) { - listener.onFailure(exp); - } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } + .sendRequest(node.get(), entityProfileAction, request, option, new TransportResponseHandler() { + + @Override + public EntityProfileResponse read(StreamInput in) throws IOException { + return new EntityProfileResponse(in); + } + + @Override + public void handleResponse(EntityProfileResponse response) { + listener.onResponse(response); + } + + @Override + public void handleException(TransportException exp) { + listener.onFailure(exp); + } + @Override + public String executor() { + return ThreadPool.Names.SAME; } - ); + + }); } catch (Exception e) { LOG.error(String.format(Locale.ROOT, "Failed to get entity profile for detector %s, entity %s", adID, entityValue), e); listener.onFailure(new TimeSeriesException(adID, FAIL_TO_GET_ENTITY_PROFILE_MSG, e)); diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java new file mode 100644 index 000000000..8843d4400 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java @@ -0,0 +1,521 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_GET_FORECASTER; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.EntityProfileRunner; +import 
org.opensearch.timeseries.Name; +import org.opensearch.timeseries.ProfileRunner; +import org.opensearch.timeseries.TaskProfile; +import org.opensearch.timeseries.TaskProfileRunner; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class BaseGetConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager, ConfigType extends Config, EntityProfileActionType extends ActionType, EntityProfileRunnerType extends EntityProfileRunner, TaskProfileType extends TaskProfile, ConfigProfileType extends ConfigProfile, ProfileActionType extends ActionType, TaskProfileRunnerType extends TaskProfileRunner, ProfileRunnerType extends ProfileRunner> + extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(BaseGetConfigTransportAction.class); + + protected final ClusterService clusterService; + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final Set allProfileTypeStrs; + protected final Set allProfileTypes; + protected final Set defaultDetectorProfileTypes; + protected final Set allEntityProfileTypeStrs; + protected final Set allEntityProfileTypes; + protected final Set defaultEntityProfileTypes; + protected final NamedXContentRegistry xContentRegistry; + protected final DiscoveryNodeFilterer nodeFilter; + protected final TransportService transportService; + protected volatile Boolean filterByEnabled; + protected final TaskManagerType taskManager; + private final Class configTypeClass; + private final String configParseFieldName; + private final List allTaskTypes; + private final String singleStreamRealTimeTaskName; + private final String hcRealTimeTaskName; + private final String singleStreamHistoricalTaskName; + private final String hcHistoricalTaskName; + private final TaskProfileRunnerType taskProfileRunner; + + public BaseGetConfigTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + TaskManagerType taskManager, + String getConfigAction, + Class configTypeClass, + String configParseFieldName, + List allTaskTypes, + String hcRealTimeTaskName, + String singleStreamRealTimeTaskName, + String 
hcHistoricalTaskName, + String singleStreamHistoricalTaskName, + Setting filterByBackendRoleEnableSetting, + TaskProfileRunnerType taskProfileRunner + ) { + super(getConfigAction, transportService, actionFilters, GetConfigRequest::new); + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + + List allProfiles = Arrays.asList(ProfileName.values()); + this.allProfileTypes = EnumSet.copyOf(allProfiles); + this.allProfileTypeStrs = getProfileListStrs(allProfiles); + List defaultProfiles = Arrays.asList(ProfileName.ERROR, ProfileName.STATE); + this.defaultDetectorProfileTypes = new HashSet<>(defaultProfiles); + + List allEntityProfiles = Arrays.asList(EntityProfileName.values()); + this.allEntityProfileTypes = EnumSet.copyOf(allEntityProfiles); + this.allEntityProfileTypeStrs = getProfileListStrs(allEntityProfiles); + List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); + this.defaultEntityProfileTypes = new HashSet<>(defaultEntityProfiles); + + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + filterByEnabled = filterByBackendRoleEnableSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleEnableSetting, it -> filterByEnabled = it); + this.transportService = transportService; + this.taskManager = taskManager; + this.configTypeClass = configTypeClass; + this.configParseFieldName = configParseFieldName; + this.allTaskTypes = allTaskTypes; + this.hcRealTimeTaskName = hcRealTimeTaskName; + this.singleStreamRealTimeTaskName = singleStreamRealTimeTaskName; + this.hcHistoricalTaskName = hcHistoricalTaskName; + this.singleStreamHistoricalTaskName = singleStreamHistoricalTaskName; + this.taskProfileRunner = taskProfileRunner; + } + + @Override + public void doExecute(Task task, GetConfigRequest request, ActionListener actionListener) { + String configID = request.getConfigID(); + User user = ParseUtils.getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_FORECASTER); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configID, + filterByEnabled, + listener, + (config) -> getExecute(request, listener), + client, + clusterService, + xContentRegistry, + configTypeClass + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + protected void getConfigAndJob( + String configID, + boolean returnJob, + boolean returnTask, + Optional realtimeConfigTask, + Optional historicalConfigTask, + ActionListener listener + ) { + MultiGetRequest.Item configItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, configID); + MultiGetRequest multiGetRequest = new MultiGetRequest().add(configItem); + if (returnJob) { + MultiGetRequest.Item jobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, configID); + multiGetRequest.add(jobItem); + } + client + .multiGet( + multiGetRequest, + onMultiGetResponse(listener, returnJob, returnTask, realtimeConfigTask, historicalConfigTask, configID) + ); + } + + public void getExecute(GetConfigRequest request, ActionListener listener) { + String configID = request.getConfigID(); + String typesStr = request.getTypeStr(); + String rawPath = request.getRawPath(); + Entity entity = request.getEntity(); + boolean all = request.isAll(); + boolean returnJob = request.isReturnJob(); + boolean returnTask = request.isReturnTask(); + + try { + if (!Strings.isEmpty(typesStr) || 
rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { + getExecuteProfile(request, entity, typesStr, all, configID, listener); + } else { + if (returnTask) { + taskManager.getAndExecuteOnLatestTasks(configID, null, null, allTaskTypes, (taskList) -> { + Optional realtimeTask = Optional.empty(); + Optional historicalTask = Optional.empty(); + if (taskList != null && taskList.size() > 0) { + Map tasks = new HashMap<>(); + List duplicateTasks = new ArrayList<>(); + for (TaskClass task : taskList) { + if (tasks.containsKey(task.getTaskType())) { + LOG + .info( + "Found duplicate latest task of config {}, task type: {}, task id: {}", + configID, + task.getTaskType(), + task.getTaskId() + ); + duplicateTasks.add(task); + continue; + } + tasks.put(task.getTaskType(), task); + } + if (duplicateTasks.size() > 0) { + taskManager.resetLatestFlagAsFalse(duplicateTasks); + } + + if (tasks.containsKey(hcRealTimeTaskName)) { + realtimeTask = Optional.ofNullable(tasks.get(hcRealTimeTaskName)); + } else if (tasks.containsKey(singleStreamRealTimeTaskName)) { + realtimeTask = Optional.ofNullable(tasks.get(singleStreamRealTimeTaskName)); + } + if (tasks.containsKey(hcHistoricalTaskName)) { + historicalTask = Optional.ofNullable(tasks.get(hcHistoricalTaskName)); + } else if (tasks.containsKey(singleStreamHistoricalTaskName)) { + historicalTask = Optional.ofNullable(tasks.get(singleStreamHistoricalTaskName)); + } else { + // AD needs to provide custom behavior for bwc, while forecasting can inherit + // the empty implementation + fillInHistoricalTaskforBwc(tasks, historicalTask); + } + } + getConfigAndJob(configID, returnJob, returnTask, realtimeTask, historicalTask, listener); + }, transportService, true, 2, listener); + } else { + getConfigAndJob(configID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); + } + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private ActionListener onMultiGetResponse( + ActionListener listener, + boolean returnJob, + boolean returnTask, + Optional realtimeTask, + Optional historicalTask, + String configId + ) { + return new ActionListener() { + @Override + public void onResponse(MultiGetResponse multiGetResponse) { + MultiGetItemResponse[] responses = multiGetResponse.getResponses(); + ConfigType config = null; + Job job = null; + String id = null; + long version = 0; + long seqNo = 0; + long primaryTerm = 0; + + for (MultiGetItemResponse response : responses) { + if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { + if (response.getResponse() == null || !response.getResponse().isExists()) { + listener + .onFailure( + new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND) + ); + return; + } + id = response.getId(); + version = response.getResponse().getVersion(); + primaryTerm = response.getResponse().getPrimaryTerm(); + seqNo = response.getResponse().getSeqNo(); + if (!response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + config = parser.namedObject(configTypeClass, configParseFieldName, null); + } catch (Exception e) { + String message = "Failed to parse config " + configId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } else if (CommonName.JOB_INDEX.equals(response.getIndex())) { + if 
(response.getResponse() != null + && response.getResponse().isExists() + && !response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + job = Job.parse(parser); + } catch (Exception e) { + String message = "Failed to parse job " + configId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } + } + listener + .onResponse( + createResponse( + version, + id, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask, + historicalTask, + returnTask, + RestStatus.OK, + null, + null, + false + ) + ); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }; + } + + protected void fillInHistoricalTaskforBwc(Map tasks, Optional historicalAdTask) {} + + protected void getExecuteProfile( + GetConfigRequest request, + Entity entity, + String typesStr, + boolean all, + String configId, + ActionListener listener + ) { + if (entity != null) { + Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); + EntityProfileRunnerType profileRunner = createEntityProfileRunner( + client, + clientUtil, + xContentRegistry, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + profileRunner.profile(configId, entity, entityProfilesToCollect, ActionListener.wrap(profile -> { + listener + .onResponse( + createResponse( + 0, + null, + 0, + 0, + null, + null, + false, + Optional.empty(), + Optional.empty(), + false, + null, + null, + profile, + true + ) + ); + }, e -> listener.onFailure(e))); + } else { + Set profilesToCollect = getProfilesToCollect(typesStr, all); + ProfileRunnerType profileRunner = createProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager, + taskProfileRunner + ); + profileRunner.profile(configId, getProfileActionListener(listener), profilesToCollect); + } + + } + + protected abstract GetConfigResponseType createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + ConfigType config, + Job job, + boolean returnJob, + Optional realtimeTask, + Optional historicalTask, + boolean returnTask, + RestStatus restStatus, + ConfigProfileType detectorProfile, + EntityProfile entityProfile, + boolean profileResponse + ); + + protected OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { + LOG.error(errorMsg, e); + return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); + } + + private Set getProfileListStrs(List profileList) { + return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); + } + + /** + * + * @param typesStr a list of input profile types separated by comma + * @param all whether we should return all profile in the response + * @return profiles to collect for an entity + */ + protected Set getEntityProfilesToCollect(String typesStr, boolean all) { + if (all) { + return this.allEntityProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultEntityProfileTypes; + } else { + // Filter out unsupported types + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return EntityProfileName.getNames(Sets.intersection(allEntityProfileTypeStrs, typesInRequest)); + } + } + + /** + * + * @param typesStr a list of input profile types separated by comma + * 
@param all whether we should return all profile in the response + * @return profiles to collect for a detector + */ + protected Set getProfilesToCollect(String typesStr, boolean all) { + if (all) { + return this.allProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultDetectorProfileTypes; + } else { + // Filter out unsupported types + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return ProfileName.getNames(Sets.intersection(allProfileTypeStrs, typesInRequest)); + } + } + + protected ActionListener getProfileActionListener(ActionListener listener) { + return ActionListener.wrap(new CheckedConsumer() { + @Override + public void accept(ConfigProfileType profile) throws Exception { + listener + .onResponse( + createResponse( + 0, + null, + 0, + 0, + null, + null, + false, + Optional.empty(), + Optional.empty(), + false, + null, + profile, + null, + true + ) + ); + } + }, exception -> { listener.onFailure(exception); }); + } + + protected abstract EntityProfileRunnerType createEntityProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + long requiredSamples + ); + + protected abstract ProfileRunnerType createProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + TaskManagerType taskManager, + TaskProfileRunnerType taskProfileRunner + ); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java new file mode 100644 index 000000000..99f4a69b3 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java @@ -0,0 +1,133 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import 
org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseJobTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ProfileActionType extends ActionType, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder, IndexJobActionHandlerType extends IndexJobActionHandler> + extends HandledTransportAction { + private final Logger logger = LogManager.getLogger(BaseJobTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final Settings settings; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final TransportService transportService; + private final Setting requestTimeOutSetting; + private final String failToStartMsg; + private final String failToStopMsg; + private final Class configClass; + private final IndexJobActionHandlerType indexJobActionHandlerType; + + public BaseJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + Setting filterByBackendRoleSetting, + String jobActionName, + Setting requestTimeOutSetting, + String failToStartMsg, + String failToStopMsg, + Class configClass, + IndexJobActionHandlerType indexJobActionHandlerType + ) { + super(jobActionName, transportService, actionFilters, JobRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.settings = settings; + this.xContentRegistry = xContentRegistry; + filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + this.requestTimeOutSetting = requestTimeOutSetting; + this.failToStartMsg = failToStartMsg; + this.failToStopMsg = failToStopMsg; + this.configClass = configClass; + this.indexJobActionHandlerType = indexJobActionHandlerType; + } + + @Override + protected void doExecute(Task task, JobRequest request, ActionListener actionListener) { + String configId = request.getConfigID(); + DateRange dateRange = request.getDateRange(); + boolean historical = request.isHistorical(); + String rawPath = request.getRawPath(); + TimeValue requestTimeout = requestTimeOutSetting.get(settings); + String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? failToStartMsg : failToStopMsg; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + + // By the time request reaches here, the user permissions are validated by Security plugin. 
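+ // Stash the thread context so the subsequent config lookup runs with the plugin's own privileges; + // the original user context is restored when the try-with-resources block below closes.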
+ User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configId, + filterByEnabled, + listener, + (config) -> executeConfig(listener, configId, dateRange, historical, rawPath, requestTimeout, user, context), + client, + clusterService, + xContentRegistry, + configClass + ); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void executeConfig( + ActionListener listener, + String configId, + DateRange dateRange, + boolean historical, + String rawPath, + TimeValue requestTimeout, + User user, + ThreadContext.StoredContext context + ) { + if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { + indexJobActionHandlerType.startConfig(configId, dateRange, user, transportService, context, listener); + } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { + indexJobActionHandlerType.stopConfig(configId, historical, user, transportService, listener); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseProfileTransportAction.java similarity index 50% rename from src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseProfileTransportAction.java index af1bbed50..398e03994 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseProfileTransportAction.java @@ -9,9 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -23,26 +21,26 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + /** * This class contains the logic to extract the stats from the nodes */ -public class ProfileTransportAction extends TransportNodesAction { - private static final Logger LOG = LogManager.getLogger(ProfileTransportAction.class); - private ModelManager modelManager; - private FeatureManager featureManager; - private CacheProvider cacheProvider; +public class BaseProfileTransportAction, CacheProviderType extends CacheProvider> + extends TransportNodesAction { + private static final Logger LOG = LogManager.getLogger(BaseProfileTransportAction.class); + private CacheProviderType cacheProvider; // 
the number of models to return. Defaults to 10. private volatile int numModelsToReturn; @@ -53,24 +51,22 @@ public class ProfileTransportAction extends TransportNodesAction maxModelNumberPerNode ) { + super( - ProfileAction.NAME, + profileAction, threadPool, clusterService, transportService, @@ -80,11 +76,9 @@ public ProfileTransportAction( ThreadPool.Names.MANAGEMENT, ProfileNodeResponse.class ); - this.modelManager = modelManager; - this.featureManager = featureManager; this.cacheProvider = cacheProvider; - this.numModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); + this.numModelsToReturn = maxModelNumberPerNode.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxModelNumberPerNode, it -> this.numModelsToReturn = it); } @Override @@ -104,41 +98,33 @@ protected ProfileNodeResponse newNodeResponse(StreamInput in) throws IOException @Override protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) { - String detectorId = request.getId(); - Set profiles = request.getProfilesToBeRetrieved(); + String configId = request.getConfigId(); + Set profiles = request.getProfilesToBeRetrieved(); int shingleSize = -1; long activeEntity = 0; long totalUpdates = 0; Map modelSize = null; List modelProfiles = null; int modelCount = 0; - if (request.isForMultiEntityDetector()) { - if (profiles.contains(DetectorProfileName.ACTIVE_ENTITIES)) { - activeEntity = cacheProvider.get().getActiveEntities(detectorId); - } - if (profiles.contains(DetectorProfileName.INIT_PROGRESS)) { - totalUpdates = cacheProvider.get().getTotalUpdates(detectorId);// get toal updates - } - if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) { - modelSize = cacheProvider.get().getModelSize(detectorId); - } - // need to provide entity info for HCAD - if (profiles.contains(DetectorProfileName.MODELS)) { - modelProfiles = cacheProvider.get().getAllModelProfile(detectorId); - modelCount = modelProfiles.size(); - int limit = Math.min(numModelsToReturn, modelCount); - if (limit != modelCount) { - LOG.info("model number limit reached"); - modelProfiles = modelProfiles.subList(0, limit); - } - } - } else { - if (profiles.contains(DetectorProfileName.COORDINATING_NODE) || profiles.contains(DetectorProfileName.SHINGLE_SIZE)) { - shingleSize = featureManager.getShingleSize(detectorId); - } + if (profiles.contains(ProfileName.ACTIVE_ENTITIES)) { + activeEntity = cacheProvider.get().getActiveEntities(configId); + } - if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(DetectorProfileName.MODELS)) { - modelSize = modelManager.getModelSize(detectorId); + // state profile requires totalUpdates as well + if (profiles.contains(ProfileName.INIT_PROGRESS) || profiles.contains(ProfileName.STATE)) { + totalUpdates = cacheProvider.get().getTotalUpdates(configId);// get total updates + } + if (profiles.contains(ProfileName.TOTAL_SIZE_IN_BYTES)) { + modelSize = cacheProvider.get().getModelSize(configId); + } + // need to provide entity info for HCAD + if (profiles.contains(ProfileName.MODELS)) { + modelProfiles = cacheProvider.get().getAllModelProfile(configId); + modelCount = modelProfiles.size(); + int limit = Math.min(numModelsToReturn, modelCount); + if (limit != modelCount) { + LOG.info("model number limit reached"); + modelProfiles = modelProfiles.subList(0, limit); } } diff --git 
a/src/main/java/org/opensearch/timeseries/transport/BaseSearchConfigInfoTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseSearchConfigInfoTransportAction.java new file mode 100644 index 000000000..536bb1466 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseSearchConfigInfoTransportAction.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseSearchConfigInfoTransportAction extends + HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(BaseSearchConfigInfoTransportAction.class); + private final Client client; + + public BaseSearchConfigInfoTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + String searchConfigActionName + ) { + super(searchConfigActionName, transportService, actionFilters, SearchConfigInfoRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, SearchConfigInfoRequest request, ActionListener actionListener) { + String name = request.getName(); + String rawPath = request.getRawPath(); + ActionListener listener = wrapRestActionListener(actionListener, CommonMessages.FAIL_TO_GET_CONFIG_INFO); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + SearchRequest searchRequest = new SearchRequest().indices(CommonName.CONFIG_INDEX); + if (rawPath.endsWith(RestHandlerUtils.COUNT)) { + // Count detectors + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + client.search(searchRequest, new ActionListener() { + + @Override + public void onResponse(SearchResponse searchResponse) { + SearchConfigInfoResponse response = new SearchConfigInfoResponse( + searchResponse.getHits().getTotalHits().value, + false + ); + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (e.getClass() == IndexNotFoundException.class) { + // Anomaly Detectors index does not exist + // Could be that user is creating first detector + SearchConfigInfoResponse response = new SearchConfigInfoResponse(0, false); + listener.onResponse(response); + } else { + listener.onFailure(e); + } + } + }); + } else { + // Match name with existing detectors + TermsQueryBuilder query = QueryBuilders.termsQuery("name.keyword", name); + SearchSourceBuilder searchSourceBuilder = new 
SearchSourceBuilder().query(query); + searchRequest.source(searchSourceBuilder); + client.search(searchRequest, new ActionListener() { + + @Override + public void onResponse(SearchResponse searchResponse) { + boolean nameExists = searchResponse.getHits().getTotalHits().value > 0; + SearchConfigInfoResponse response = new SearchConfigInfoResponse(0, nameExists); + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (e.getClass() == IndexNotFoundException.class) { + // Anomaly Detectors index does not exist + // Could be that user is creating first detector + SearchConfigInfoResponse response = new SearchConfigInfoResponse(0, false); + listener.onResponse(response); + } else { + listener.onFailure(e); + } + } + }); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseStatsNodesTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseStatsNodesTransportAction.java new file mode 100644 index 000000000..ac43684ba --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseStatsNodesTransportAction.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.transport.TransportService; + +public class BaseStatsNodesTransportAction extends + TransportNodesAction { + + private Stats stats; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param stats TimeSeriesStats object + */ + public BaseStatsNodesTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Stats stats, + String statsNodesActionName + ) { + super( + statsNodesActionName, + threadPool, + clusterService, + transportService, + actionFilters, + StatsRequest::new, + StatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + StatsNodeResponse.class + ); + this.stats = stats; + } + + @Override + protected StatsNodesResponse newResponse(StatsRequest request, List responses, List failures) { + return new StatsNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected StatsNodeRequest newNodeRequest(StatsRequest request) { + return new StatsNodeRequest(request); + } + + @Override + protected StatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new StatsNodeResponse(in); + } + + @Override + protected StatsNodeResponse nodeOperation(StatsNodeRequest request) { + return createADStatsNodeResponse(request.getADStatsRequest()); + } + + protected StatsNodeResponse createADStatsNodeResponse(StatsRequest statsRequest) { + Map statValues = new HashMap<>(); + Set statsToBeRetrieved = statsRequest.getStatsToBeRetrieved(); + + for (String statName : 
stats.getNodeStats().keySet()) { + if (statsToBeRetrieved.contains(statName)) { + statValues.put(statName, stats.getStats().get(statName).getValue()); + } + } + + return new StatsNodeResponse(clusterService.localNode(), statValues); + } + +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseStatsTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseStatsTransportAction.java new file mode 100644 index 000000000..72c533e2e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseStatsTransportAction.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.transport.TransportService; + +public abstract class BaseStatsTransportAction extends HandledTransportAction { + public final Logger logger = LogManager.getLogger(BaseStatsTransportAction.class); + + protected final Client client; + protected final Stats stats; + protected final ClusterService clusterService; + + public BaseStatsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + Stats stats, + ClusterService clusterService, + String statsAction + + ) { + super(statsAction, transportService, actionFilters, StatsRequest::new); + this.client = client; + this.stats = stats; + this.clusterService = clusterService; + } + + @Override + protected void doExecute(Task task, StatsRequest request, ActionListener actionListener) { + ActionListener listener = wrapRestActionListener(actionListener, CommonMessages.FAIL_TO_GET_STATS); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + getStats(client, listener, request); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + /** + * Make the 2 requests to get the node and cluster statistics + * + * @param client Client + * @param listener Listener to send response + * @param statsRequest Request containing stats to be retrieved + */ + public void getStats(Client client, ActionListener listener, StatsRequest statsRequest) { + // Use MultiResponsesDelegateActionListener to execute 2 async requests and create the response once they finish + MultiResponsesDelegateActionListener delegateListener = new MultiResponsesDelegateActionListener<>( + getRestStatsListener(listener), + 2, + "Unable to return Stats", + false + ); + + getClusterStats(client, delegateListener, statsRequest); + getNodeStats(client, delegateListener, statsRequest); + } + + /** + * Listener sends response once Node Stats and Cluster Stats are gathered + * + * @param 
listener Listener to send response + * @return ActionListener for StatsResponse + */ + public ActionListener getRestStatsListener(ActionListener listener) { + return ActionListener + .wrap( + statsResponse -> { listener.onResponse(new StatsTimeSeriesResponse(statsResponse)); }, + exception -> listener.onFailure(new OpenSearchStatusException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)) + ); + } + + /** + * Collect Cluster Stats into map to be retrieved + * + * @param statsRequest Request containing stats to be retrieved + * @return Map containing Cluster Stats + */ + protected Map getClusterStatsMap(StatsRequest statsRequest) { + Map clusterStats = new HashMap<>(); + Set statsToBeRetrieved = statsRequest.getStatsToBeRetrieved(); + stats + .getClusterStats() + .entrySet() + .stream() + .filter(s -> statsToBeRetrieved.contains(s.getKey())) + .forEach(s -> clusterStats.put(s.getKey(), s.getValue().getValue())); + return clusterStats; + } + + protected abstract void getClusterStats( + Client client, + MultiResponsesDelegateActionListener listener, + StatsRequest adStatsRequest + ); + + protected abstract void getNodeStats( + Client client, + MultiResponsesDelegateActionListener listener, + StatsRequest adStatsRequest + ); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseSuggestConfigParamTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseSuggestConfigParamTransportAction.java new file mode 100644 index 000000000..4bd231686 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseSuggestConfigParamTransportAction.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; + +import java.time.Clock; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.rest.handler.IntervalCalculation; +import org.opensearch.timeseries.rest.handler.LatestTimeRetriever; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public abstract class BaseSuggestConfigParamTransportAction extends + HandledTransportAction { + public static final Logger logger = LogManager.getLogger(BaseSuggestConfigParamTransportAction.class); + + protected final Client client; + protected final SecurityClientUtil clientUtil; 
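+ // used by suggestExecute below (via LatestTimeRetriever) to find the latest data point when suggesting an interval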
+ protected final SearchFeatureDao searchFeatureDao; + protected volatile Boolean filterByEnabled; + protected Clock clock; + protected AnalysisType context; + + public BaseSuggestConfigParamTransportAction( + String actionName, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ActionFilters actionFilters, + TransportService transportService, + Setting filterByBackendRoleSetting, + AnalysisType context, + SearchFeatureDao searchFeatureDao + ) { + super(actionName, transportService, actionFilters, SuggestConfigParamRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + this.clock = Clock.systemUTC(); + this.context = context; + this.searchFeatureDao = searchFeatureDao; + } + + @Override + protected void doExecute(Task task, SuggestConfigParamRequest request, ActionListener listener) { + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute(user, listener, () -> suggestExecute(request, user, context, listener)); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + public void resolveUserAndExecute(User requestedUser, ActionListener listener, ExecutorFunction function) { + try { + // Check if user has backend roles + // When filter by is enabled, block users validating detectors who do not have + // backend roles. + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new TimeSeriesException(error)); + return; + } + } + // Validate analysis + function.execute(); + } catch (Exception e) { + listener.onFailure(e); + } + } + + public void suggestExecute( + SuggestConfigParamRequest request, + User user, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + storedContext.restore(); + Config config = request.getConfig(); + if (request.getParam().equals(Forecaster.FORECAST_INTERVAL_FIELD)) { + IntervalCalculation intervalCalculation = new IntervalCalculation( + config, + request.getRequestTimeout(), + client, + clientUtil, + user, + context, + clock + ); + LatestTimeRetriever latestTimeRetriever = new LatestTimeRetriever( + config, + request.getRequestTimeout(), + clientUtil, + client, + user, + context, + searchFeatureDao + ); + + ActionListener intervalSuggestionListener = ActionListener + .wrap(interval -> listener.onResponse(new SuggestConfigParamResponse(interval)), listener::onFailure); + ActionListener, Map>> latestTimeListener = ActionListener.wrap(latestEntityAttributes -> { + Optional latestTime = latestEntityAttributes.getLeft(); + if (latestTime.isPresent()) { + intervalCalculation.findInterval(latestTime.get(), latestEntityAttributes.getRight(), intervalSuggestionListener); + } else { + listener.onFailure(new TimeSeriesException("Empty data. 
Cannot find a good interval.")); + } + + }, exception -> { + listener.onFailure(exception); + logger.error("Failed to create search request for last data point", exception); + }); + + latestTimeRetriever.checkIfHC(latestTimeListener); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java new file mode 100644 index 000000000..9967a8218 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseValidateConfigTransportAction.java @@ -0,0 +1,228 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; + +import java.time.Clock; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.rest.handler.Processor; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public abstract class BaseValidateConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement> + extends HandledTransportAction { + public static final Logger logger = LogManager.getLogger(BaseValidateConfigTransportAction.class); + + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final ClusterService clusterService; + protected final NamedXContentRegistry xContentRegistry; + protected final IndexManagementType indexManagement; + protected final SearchFeatureDao searchFeatureDao; + protected volatile Boolean filterByEnabled; + protected Clock clock; + protected Settings settings; + + public BaseValidateConfigTransportAction( + String actionName, + Client client, + SecurityClientUtil clientUtil, + 
ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Settings settings, + IndexManagementType indexManagement, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao, + Setting filterByBackendRoleSetting + ) { + super(actionName, transportService, actionFilters, ValidateConfigRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.indexManagement = indexManagement; + this.filterByEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByEnabled = it); + this.searchFeatureDao = searchFeatureDao; + this.clock = Clock.systemUTC(); + this.settings = settings; + } + + @Override + protected void doExecute(Task task, ValidateConfigRequest request, ActionListener listener) { + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute(user, listener, () -> validateExecute(request, user, context, listener)); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + public void resolveUserAndExecute(User requestedUser, ActionListener listener, ExecutorFunction function) { + try { + // Check if user has backend roles + // When filter by is enabled, block users validating detectors who do not have backend roles. + if (filterByEnabled) { + String error = checkFilterByBackendRoles(requestedUser); + if (error != null) { + listener.onFailure(new TimeSeriesException(error)); + return; + } + } + // Validate analysis + function.execute(); + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void checkIndicesAndExecute( + List indices, + ExecutorFunction function, + ActionListener listener + ) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> function.execute(), e -> { + if (e instanceof IndexNotFoundException) { + // IndexNotFoundException is converted to a ValidationException that gets + // parsed to a ConfigValidationIssue that is returned to + // the user as a response indicating the index doesn't exist + ConfigValidationIssue issue = parseValidationException( + new ValidationException(ADCommonMessages.INDEX_NOT_FOUND, ValidationIssueType.INDICES, ValidationAspect.DETECTOR) + ); + listener.onResponse(new ValidateConfigResponse(issue)); + return; + } + logger.error(e); + listener.onFailure(e); + })); + } + + protected Map getFeatureSubIssuesFromErrorMessage(String errorMessage) { + Map result = new HashMap<>(); + String[] subIssueMessagesSuffix = errorMessage.split(", "); + for (int i = 0; i < subIssueMessagesSuffix.length; i++) { + result.put(subIssueMessagesSuffix[i].split(": ")[1], subIssueMessagesSuffix[i].split(": ")[0]); + } + return result; + } + + public ConfigValidationIssue parseValidationException(ValidationException exception) { + String originalErrorMessage = exception.getMessage(); + String errorMessage = ""; + Map subIssues = null; + IntervalTimeConfiguration intervalSuggestion = exception.getIntervalSuggestion(); + switch (exception.getType()) { + case FEATURE_ATTRIBUTES: + int firstLeftBracketIndex = originalErrorMessage.indexOf("["); + int lastRightBracketIndex = 
originalErrorMessage.lastIndexOf("]"); + if (firstLeftBracketIndex != -1) { + // if feature issue messages are between square brackets like + // [Feature has issue: A, Feature has issue: B] + errorMessage = originalErrorMessage.substring(firstLeftBracketIndex + 1, lastRightBracketIndex); + subIssues = getFeatureSubIssuesFromErrorMessage(errorMessage); + } else { + // features having issue like over max feature limit, duplicate feature name, etc. + errorMessage = originalErrorMessage; + } + break; + case NAME: + case CATEGORY: + case DETECTION_INTERVAL: + case FILTER_QUERY: + case TIMEFIELD_FIELD: + case SHINGLE_SIZE_FIELD: + case WINDOW_DELAY: + case RESULT_INDEX: + case GENERAL_SETTINGS: + case AGGREGATION: + case TIMEOUT: + case INDICES: + case FORECAST_INTERVAL: + case IMPUTATION: + case HORIZON_SIZE: + case TRANSFORM_DECAY: + errorMessage = originalErrorMessage; + break; + } + return new ConfigValidationIssue(exception.getAspect(), exception.getType(), errorMessage, subIssues, intervalSuggestion); + } + + public void validateExecute( + ValidateConfigRequest request, + User user, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + storedContext.restore(); + Config detector = request.getConfig(); + ActionListener validateListener = ActionListener.wrap(response -> { + logger.debug("Result of validation process " + response); + // forcing response to be empty + listener.onResponse(new ValidateConfigResponse((ConfigValidationIssue) null)); + }, exception -> { + if (exception instanceof ValidationException) { + // ValidationException is converted to validation issues and returned as a response to the user + ConfigValidationIssue issue = parseValidationException((ValidationException) exception); + listener.onResponse(new ValidateConfigResponse(issue)); + return; + } + logger.error(exception); + listener.onFailure(exception); + }); + checkIndicesAndExecute(detector.getIndices(), () -> { + try { + createProcessor(detector, request, user).start(validateListener); + } catch (Exception exception) { + String errorMessage = String + .format(Locale.ROOT, "Unknown exception caught while validating detector %s", request.getConfig()); + logger.error(errorMessage, exception); + listener.onFailure(exception); + } + }, listener); + } + + protected abstract Processor createProcessor(Config detector, ValidateConfigRequest request, User user); +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java new file mode 100644 index 000000000..c6b4f1285 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BooleanNodeResponse.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class BooleanNodeResponse extends BaseNodeResponse { + private final boolean answer; + + public BooleanNodeResponse(StreamInput in) throws IOException { + super(in); + answer = in.readBoolean(); + } + + public BooleanNodeResponse(DiscoveryNode node, boolean answer) { + super(node); + this.answer = answer; + } + + public boolean isAnswerTrue() { + return answer; + } + + @Override + public void writeTo(StreamOutput out) throws 
IOException { + super.writeTo(out); + out.writeBoolean(answer); + } } diff --git a/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java b/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java new file mode 100644 index 000000000..8eb18475a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BooleanResponse.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; + +public class BooleanResponse extends BaseNodesResponse implements ToXContentFragment { + private final boolean answer; + + public BooleanResponse(StreamInput in) throws IOException { + super(in); + answer = in.readBoolean(); + } + + public BooleanResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + this.answer = nodes.stream().anyMatch(response -> response.isAnswerTrue()); + } + + public boolean isAnswerTrue() { + return answer; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(answer); + } + + @Override + protected List readNodesFrom(StreamInput in) throws IOException { + return in.readList(BooleanNodeResponse::new); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(CommonName.ANSWER_FIELD, answer); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/CronNodeRequest.java similarity index 93% rename from src/main/java/org/opensearch/ad/transport/CronNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/CronNodeRequest.java index a5362ff46..aef33bb3c 100644 --- a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/CronNodeResponse.java similarity index 93% rename from src/main/java/org/opensearch/ad/transport/CronNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/CronNodeResponse.java index f1e9fb0e1..b83e049d3 100644 --- a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -20,7 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; public class CronNodeResponse extends BaseNodeResponse implements ToXContentObject { - static String NODE_ID = "node_id"; + public static String NODE_ID = "node_id"; public CronNodeResponse(StreamInput in) throws IOException { super(in); diff --git a/src/main/java/org/opensearch/ad/transport/CronRequest.java b/src/main/java/org/opensearch/timeseries/transport/CronRequest.java similarity index 95% rename from src/main/java/org/opensearch/ad/transport/CronRequest.java rename to src/main/java/org/opensearch/timeseries/transport/CronRequest.java index 0f91ae676..9f1add649 100644 --- a/src/main/java/org/opensearch/ad/transport/CronRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/CronResponse.java b/src/main/java/org/opensearch/timeseries/transport/CronResponse.java similarity index 94% rename from src/main/java/org/opensearch/ad/transport/CronResponse.java rename to src/main/java/org/opensearch/timeseries/transport/CronResponse.java index 13332c3af..56998f2cf 100644 --- a/src/main/java/org/opensearch/ad/transport/CronResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -23,7 +23,7 @@ import org.opensearch.core.xcontent.XContentBuilder; public class CronResponse extends BaseNodesResponse implements ToXContentFragment { - static String NODES_JSON_KEY = "nodes"; + public static String NODES_JSON_KEY = "nodes"; public CronResponse(StreamInput in) throws IOException { super(in); diff --git a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java similarity index 62% rename from src/main/java/org/opensearch/ad/transport/CronTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java index 82075d035..a7e451bfb 100644 --- a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/CronTransportAction.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -19,27 +19,34 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.CronAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.transport.TransportService; public class CronTransportAction extends TransportNodesAction { private final Logger LOG = LogManager.getLogger(CronTransportAction.class); private NodeStateManager transportStateManager; - private ModelManager modelManager; + private ADModelManager adModelManager; private FeatureManager featureManager; - private CacheProvider cacheProvider; - private EntityColdStarter entityColdStarter; + private ADCacheProvider adCacheProvider; + private ForecastCacheProvider forecastCacheProvider; + private ADEntityColdStart adEntityColdStarter; + private ForecastColdStart forecastColdStarter; private ADTaskManager adTaskManager; + private ForecastTaskManager forecastTaskManager; @Inject public CronTransportAction( @@ -48,11 +55,14 @@ public CronTransportAction( TransportService transportService, ActionFilters actionFilters, NodeStateManager tarnsportStatemanager, - ModelManager modelManager, + ADModelManager adModelManager, FeatureManager featureManager, - CacheProvider cacheProvider, - EntityColdStarter entityColdStarter, - ADTaskManager adTaskManager + ADCacheProvider adCacheProvider, + ForecastCacheProvider forecastCacheProvider, + ADEntityColdStart adEntityColdStarter, + ForecastColdStart forecastColdStarter, + ADTaskManager adTaskManager, + ForecastTaskManager forecastTaskManager ) { super( CronAction.NAME, @@ -66,11 +76,14 @@ public CronTransportAction( CronNodeResponse.class ); this.transportStateManager = tarnsportStatemanager; - this.modelManager = modelManager; + this.adModelManager = adModelManager; this.featureManager = featureManager; - this.cacheProvider = cacheProvider; - this.entityColdStarter = entityColdStarter; + this.adCacheProvider = adCacheProvider; + this.forecastCacheProvider = forecastCacheProvider; + this.adEntityColdStarter = adEntityColdStarter; + this.forecastColdStarter = forecastColdStarter; this.adTaskManager = adTaskManager; + this.forecastTaskManager = forecastTaskManager; } @Override @@ -97,27 +110,27 @@ protected CronNodeResponse newNodeResponse(StreamInput in) throws IOException { */ @Override protected CronNodeResponse nodeOperation(CronNodeRequest request) { - LOG.info("Start running AD hourly cron."); + LOG.info("Start running hourly cron."); + // 
====================== + // AD + // ====================== // makes checkpoints for hosted models and stop hosting models not actively // used. // for single-entity detector - modelManager - .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining model", e))); + adModelManager + .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining ad model", e))); // for multi-entity detector - cacheProvider.get().maintenance(); + adCacheProvider.get().maintenance(); // delete unused buffered shingle data featureManager.maintenance(); - // delete unused transport state - transportStateManager.maintenance(); - - entityColdStarter.maintenance(); + adEntityColdStarter.maintenance(); // clean child tasks and AD results of deleted detector level task - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); // clean AD results of deleted detector - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); // maintain running historical tasks: reset task state as stopped if not running and clean stale running entities adTaskManager.maintainRunningHistoricalTasks(transportService, 100); @@ -125,6 +138,22 @@ protected CronNodeResponse nodeOperation(CronNodeRequest request) { // maintain running realtime tasks: clean stale running realtime task cache adTaskManager.maintainRunningRealtimeTasks(); + // ====================== + // Forecast + // ====================== + forecastCacheProvider.get().maintenance(); + forecastColdStarter.maintenance(); + // clean child tasks and forecast results of deleted forecaster level task + forecastTaskManager.cleanChildTasksAndResultsOfDeletedTask(); + forecastTaskManager.cleanResultOfDeletedConfig(); + forecastTaskManager.maintainRunningRealtimeTasks(); + + // ====================== + // Common + // ====================== + // delete unused transport state + transportStateManager.maintenance(); + return new CronNodeResponse(clusterService.localNode()); } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteConfigRequest.java similarity index 60% rename from src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteConfigRequest.java index f87b6e0a1..93980ce83 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,40 +17,40 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.constant.CommonMessages; -public class DeleteAnomalyDetectorRequest extends ActionRequest { +public class DeleteConfigRequest extends ActionRequest { - private String detectorID; + private String configID; - public DeleteAnomalyDetectorRequest(StreamInput in) throws IOException { + public DeleteConfigRequest(StreamInput in) throws IOException { super(in); - this.detectorID = in.readString(); + this.configID = in.readString(); } - public DeleteAnomalyDetectorRequest(String detectorID) { + public DeleteConfigRequest(String detectorID) { super(); - this.detectorID = detectorID; + this.configID = detectorID; } - public String getDetectorID() { - return detectorID; + public String getConfigID() { + return configID; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(detectorID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(detectorID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java similarity index 67% rename from src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java index d10eef4c3..6af6b9fcc 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -22,26 +22,26 @@ */ public class DeleteModelNodeRequest extends TransportRequest { - private String adID; + private String configID; DeleteModelNodeRequest() {} - DeleteModelNodeRequest(StreamInput in) throws IOException { + public DeleteModelNodeRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } - DeleteModelNodeRequest(DeleteModelRequest request) { - this.adID = request.getAdID(); + public DeleteModelNodeRequest(DeleteModelRequest request) { + this.configID = request.getAdID(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java similarity index 96% rename from src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java index c71e7368c..a57cb0d30 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java similarity index 77% rename from src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java index 9ec58acda..d6b119e6a 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,24 +17,24 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.nodes.BaseNodesRequest; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; /** * Request should be sent from the handler logic of transport delete detector API * */ public class DeleteModelRequest extends BaseNodesRequest implements ToXContentObject { - private String adID; + private String configID; public String getAdID() { - return adID; + return configID; } public DeleteModelRequest() { @@ -43,25 +43,25 @@ public DeleteModelRequest() { public DeleteModelRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } public DeleteModelRequest(String adID, DiscoveryNode... nodes) { super(nodes); - this.adID = adID; + this.configID = adID; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } @@ -69,7 +69,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java similarity index 97% rename from src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java index f2cbe2468..a2154481a 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java b/src/main/java/org/opensearch/timeseries/transport/EntityProfileRequest.java similarity index 80% rename from src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java rename to src/main/java/org/opensearch/timeseries/transport/EntityProfileRequest.java index 7e4054a8a..edee7f379 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityProfileRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -19,27 +19,27 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.EntityProfileName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; public class EntityProfileRequest extends ActionRequest implements ToXContentObject { public static final String ENTITY = "entity"; public static final String PROFILES = "profiles"; - private String adID; + private String configID; // changed from String to Entity since 1.1 private Entity entityValue; private Set profilesToCollect; public EntityProfileRequest(StreamInput in) throws IOException { super(in); - adID = in.readString(); + configID = in.readString(); entityValue = new Entity(in); int size = in.readVInt(); @@ -53,13 +53,13 @@ public EntityProfileRequest(StreamInput in) throws IOException { public EntityProfileRequest(String adID, Entity entityValue, Set profilesToCollect) { super(); - this.adID = adID; + this.configID = adID; this.entityValue = entityValue; this.profilesToCollect = profilesToCollect; } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } public Entity getEntityValue() { @@ -73,7 +73,7 @@ public Set getProfilesToCollect() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); entityValue.writeTo(out); out.writeVInt(profilesToCollect.size()); @@ -85,14 +85,14 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } if (entityValue == null) { validationException = addValidationError("Entity value is missing", validationException); } if (profilesToCollect == null || profilesToCollect.isEmpty()) { - validationException = 
addValidationError(ADCommonMessages.EMPTY_PROFILES_COLLECT, validationException); + validationException = addValidationError(CommonMessages.EMPTY_PROFILES_COLLECT, validationException); } return validationException; } @@ -100,7 +100,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.field(ENTITY, entityValue); builder.field(PROFILES, profilesToCollect); builder.endObject(); diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java b/src/main/java/org/opensearch/timeseries/transport/EntityProfileResponse.java similarity index 95% rename from src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java rename to src/main/java/org/opensearch/timeseries/transport/EntityProfileResponse.java index 1b8b51da2..8d7dc5843 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityProfileResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Optional; @@ -17,13 +17,13 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfileOnNode; public class EntityProfileResponse extends ActionResponse implements ToXContentObject { public static final String ACTIVE = "active"; @@ -128,7 +128,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(TOTAL_UPDATES, totalUpdates); } if (modelProfile != null) { - builder.field(ADCommonName.MODEL, modelProfile); + builder.field(CommonName.MODEL, modelProfile); } builder.endObject(); return builder; @@ -140,7 +140,7 @@ public String toString() { builder.append(ACTIVE, isActive); builder.append(LAST_ACTIVE_TS, lastActiveMs); builder.append(TOTAL_UPDATES, totalUpdates); - builder.append(ADCommonName.MODEL, modelProfile); + builder.append(CommonName.MODEL, modelProfile); return builder.toString(); } diff --git a/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java new file mode 100644 index 000000000..39b0eb85e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import 
org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.SaveResultStrategy; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Shared code to implement an entity result transport action + * (e.g., EntityForecastResultTransportAction) + * + */ +public class EntityResultProcessor, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, SaveResultStrategyType extends SaveResultStrategy, ColdStartWorkerType extends ColdStartWorker, HCCheckpointReadWorkerType extends CheckpointReadWorker, ColdEntityWorkerType extends ColdEntityWorker> { + + private static final Logger LOG = LogManager.getLogger(EntityResultProcessor.class); + + private CacheProvider cache; + private ModelManagerType modelManager; + private Stats stats; + private ColdStartWorkerType entityColdStartWorker; + private HCCheckpointReadWorkerType checkpointReadQueue; + private ColdEntityWorkerType coldEntityQueue; + private SaveResultStrategyType saveResultStrategy; + private StatNames modelCorruptionStat; + + public EntityResultProcessor( + CacheProvider cache, + ModelManagerType manager, + Stats stats, + ColdStartWorkerType entityColdStartWorker, + HCCheckpointReadWorkerType checkpointReadQueue, + ColdEntityWorkerType coldEntityQueue, + SaveResultStrategyType saveResultStrategy, + StatNames modelCorruptionStat + ) { + this.cache = cache; + this.modelManager = manager; + this.stats = stats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.saveResultStrategy = saveResultStrategy; + this.modelCorruptionStat = modelCorruptionStat; + } + + public ActionListener> onGetConfig( + ActionListener listener, + String forecasterId, + EntityResultRequest request, + Optional prevException, + AnalysisType analysisType + ) { + return 
ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + return; + } + + Config config = configOptional.get(); + + if (request.getEntities() == null) { + listener.onFailure(new EndRunException(forecasterId, "Failed to get any entities from request.", false)); + return; + } + + Map<Entity, double[]> cacheMissEntities = new HashMap<>(); + for (Entry<Entity, double[]> entityEntry : request.getEntities().entrySet()) { + Entity entity = entityEntry.getKey(); + + if (isEntityFromOldNodeMsg(entity) && config.getCategoryFields() != null && config.getCategoryFields().size() == 1) { + Map<String, String> attrValues = entity.getAttributes(); + // handle a request from a version before OpenSearch 1.1. + entity = Entity.createSingleAttributeEntity(config.getCategoryFields().get(0), attrValues.get(CommonName.EMPTY_FIELD)); + } + + Optional<String> modelIdOptional = entity.getModelId(forecasterId); + if (modelIdOptional.isEmpty()) { + continue; + } + + String modelId = modelIdOptional.get(); + double[] datapoint = entityEntry.getValue(); + ModelState entityModel = cache.get().get(modelId, config); + if (entityModel == null) { + // cache miss + cacheMissEntities.put(entity, datapoint); + continue; + } + try { + IntermediateResultType result = modelManager + .getResult( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + entityModel, + modelId, + Optional.ofNullable(entity), + config, + request.getTaskId() + ); + + saveResultStrategy + .saveResult( + result, + config, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + modelId, + datapoint, + Optional.of(entity), + request.getTaskId() + ); + } catch (IllegalArgumentException e) { + // fail to score likely due to model corruption. Re-cold start to recover.
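+ // e.g., a checkpoint written by a different RCF version may deserialize into a model + // the current code cannot score; dropping the cached model below and re-queuing a + // cold start rebuilds it from historical data.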
+ LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(modelCorruptionStat.getName()).increment(); + cache.get().removeModel(forecasterId, modelId); + entityColdStartWorker + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + datapoint, + request.getStart(), + entity, + request.getTaskId() + ) + ); + } + } + + // split hot and cold entities + Pair<List<Entity>, List<Entity>> hotColdEntities = cache + .get() + .selectUpdateCandidate(cacheMissEntities.keySet(), forecasterId, config); + + List<FeatureRequest> hotEntityRequests = new ArrayList<>(); + List<FeatureRequest> coldEntityRequests = new ArrayList<>(); + + for (Entity hotEntity : hotColdEntities.getLeft()) { + double[] hotEntityValue = cacheMissEntities.get(hotEntity); + if (hotEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); + continue; + } + hotEntityRequests + .add( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + // hot entities have MEDIUM priority + RequestPriority.MEDIUM, + hotEntityValue, + request.getStart(), + hotEntity, + request.getTaskId() + ) + ); + } + + for (Entity coldEntity : hotColdEntities.getRight()) { + double[] coldEntityValue = cacheMissEntities.get(coldEntity); + if (coldEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); + continue; + } + coldEntityRequests + .add( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + // cold entities have LOW priority + RequestPriority.LOW, + coldEntityValue, + request.getStart(), + coldEntity, + request.getTaskId() + ) + ); + } + + checkpointReadQueue.putAll(hotEntityRequests); + coldEntityQueue.putAll(coldEntityRequests); + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "failed to get entity's analysis result for config [{}]: start: [{}], end: [{}]", + forecasterId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } + + /** + * Whether the received entity comes from a node that doesn't support multi-category fields. + * This can happen during a rolling upgrade or blue/green deployment. + * + * Specifically, when receiving an EntityResultRequest from an incompatible node, + * EntityResultRequest(StreamInput in) gets a String that represents an entity. + * But the Entity class requires both a category field name and a value. Since we + * don't have access to the detector config in EntityResultRequest(StreamInput in), + * we put CommonName.EMPTY_FIELD as the placeholder. In this method, + * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity + * comes from an incompatible node. If it does, we add the field name back, + * as EntityResultTransportAction has access to the detector config object. + * + * @param categoricalValues deserialized Entity from inbound message. + * @return Whether the received entity comes from a node that doesn't support multi-category fields.
+ */ + private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { + Map<String, String> attrValues = categoricalValues.getAttributes(); + return (attrValues != null && attrValues.containsKey(CommonName.EMPTY_FIELD)); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java similarity index 69% rename from src/main/java/org/opensearch/ad/transport/EntityResultRequest.java rename to src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java index 91041f447..8177178f4 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,32 +17,31 @@ import java.util.Locale; import java.util.Map; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; public class EntityResultRequest extends ActionRequest implements ToXContentObject { - private static final Logger LOG = LogManager.getLogger(EntityResultRequest.class); - private String detectorId; + protected String configId; // changed from Map<String, double[]> to Map<Entity, double[]> - private Map<Entity, double[]> entities; - private long start; - private long end; + protected Map<Entity, double[]> entities; + // data start/end time epoch + protected long start; + protected long end; + protected AnalysisType analysisType; + protected String taskId; public EntityResultRequest(StreamInput in) throws IOException { super(in); - this.detectorId = in.readString(); + this.configId = in.readString(); // guarded with version check.
Just in case we receive requests from older node where we use String // to represent an entity @@ -50,18 +49,33 @@ public EntityResultRequest(StreamInput in) throws IOException { this.start = in.readLong(); this.end = in.readLong(); + + // newly added + if (in.available() > 0) { + analysisType = in.readEnum(AnalysisType.class); + taskId = in.readOptionalString(); + } } - public EntityResultRequest(String detectorId, Map entities, long start, long end) { + public EntityResultRequest( + String configId, + Map entities, + long start, + long end, + AnalysisType analysisType, + String taskId + ) { super(); - this.detectorId = detectorId; + this.configId = configId; this.entities = entities; this.start = start; this.end = end; + this.analysisType = analysisType; + this.taskId = taskId; } - public String getId() { - return this.detectorId; + public String getConfigId() { + return this.configId; } public Map getEntities() { @@ -76,23 +90,33 @@ public long getEnd() { return this.end; } + public AnalysisType getAnalysisType() { + return analysisType; + } + + public String getTaskId() { + return taskId; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(this.detectorId); + out.writeString(this.configId); // guarded with version check. Just in case we send requests to older node where we use String // to represent an entity out.writeMap(entities, (s, e) -> e.writeTo(s), StreamOutput::writeDoubleArray); out.writeLong(this.start); out.writeLong(this.end); + out.writeEnum(analysisType); + out.writeOptionalString(taskId); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(detectorId)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configId)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } if (start <= 0 || end <= 0 || start > end) { validationException = addValidationError( @@ -106,7 +130,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, detectorId); + builder.field(CommonName.CONFIG_ID_KEY, configId); builder.field(CommonName.START_JSON_KEY, start); builder.field(CommonName.END_JSON_KEY, end); builder.startArray(CommonName.ENTITIES_JSON_KEY); @@ -119,6 +143,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } } builder.endArray(); + builder.field(CommonName.ANALYSIS_TYPE_FIELD, analysisType); + builder.field(CommonName.TASK_ID_FIELD, taskId); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java new file mode 100644 index 000000000..4c2895378 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ForecastRunOnceProfileNodeRequest.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.transport.ForecastRunOnceProfileRequest; +import 
org.opensearch.transport.TransportRequest; + +public class ForecastRunOnceProfileNodeRequest extends TransportRequest { + private final ForecastRunOnceProfileRequest request; + + public ForecastRunOnceProfileNodeRequest(StreamInput in) throws IOException { + super(in); + request = new ForecastRunOnceProfileRequest(in); + } + + public ForecastRunOnceProfileNodeRequest(ForecastRunOnceProfileRequest request) { + this.request = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } + + public String getConfigId() { + return request.getConfigId(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java index aef29626d..1aed87c66 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -19,9 +19,9 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.model.Entity; -public class GetAnomalyDetectorRequest extends ActionRequest { +public class GetConfigRequest extends ActionRequest { - private String detectorID; + private String configID; private long version; private boolean returnJob; private boolean returnTask; @@ -30,9 +30,9 @@ public class GetAnomalyDetectorRequest extends ActionRequest { private boolean all; private Entity entity; - public GetAnomalyDetectorRequest(StreamInput in) throws IOException { + public GetConfigRequest(StreamInput in) throws IOException { super(in); - detectorID = in.readString(); + configID = in.readString(); version = in.readLong(); returnJob = in.readBoolean(); returnTask = in.readBoolean(); @@ -44,7 +44,7 @@ public GetAnomalyDetectorRequest(StreamInput in) throws IOException { } } - public GetAnomalyDetectorRequest( + public GetConfigRequest( String detectorID, long version, boolean returnJob, @@ -55,7 +55,7 @@ public GetAnomalyDetectorRequest( Entity entity ) { super(); - this.detectorID = detectorID; + this.configID = detectorID; this.version = version; this.returnJob = returnJob; this.returnTask = returnTask; @@ -65,8 +65,8 @@ public GetAnomalyDetectorRequest( this.entity = entity; } - public String getDetectorID() { - return detectorID; + public String getConfigID() { + return configID; } public long getVersion() { @@ -100,7 +100,7 @@ public Entity getEntity() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(detectorID); + out.writeString(configID); out.writeLong(version); out.writeBoolean(returnJob); out.writeBoolean(returnTask); diff --git a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java new file mode 100644 index 000000000..98b56930f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source 
license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.DateRange; + +public class JobRequest extends ActionRequest { + + private String configID; + private DateRange dateRange; + private boolean historical; + private String rawPath; + + public JobRequest(StreamInput in) throws IOException { + super(in); + configID = in.readString(); + rawPath = in.readString(); + if (in.readBoolean()) { + dateRange = new DateRange(in); + } + historical = in.readBoolean(); + } + + public JobRequest(String detectorID, String rawPath) { + this(detectorID, null, false, rawPath); + } + + /** + * Constructor function. + * + * The dateRange and historical boolean can be passed in individually. + * The historical flag is for stopping analysis, the dateRange is for + * starting analysis. It's ok if historical is true but dateRange is + * null. + * + * @param configID config identifier + * @param dateRange analysis date range + * @param historical historical analysis or not + * @param rawPath raw request path + */ + public JobRequest(String configID, DateRange dateRange, boolean historical, String rawPath) { + super(); + this.configID = configID; + this.dateRange = dateRange; + this.historical = historical; + this.rawPath = rawPath; + } + + public String getConfigID() { + return configID; + } + + public DateRange getDateRange() { + return dateRange; + } + + public String getRawPath() { + return rawPath; + } + + public boolean isHistorical() { + return historical; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configID); + out.writeString(rawPath); + if (dateRange != null) { + out.writeBoolean(true); + dateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeBoolean(historical); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/JobResponse.java b/src/main/java/org/opensearch/timeseries/transport/JobResponse.java new file mode 100644 index 000000000..faa7df2c8 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/JobResponse.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class JobResponse extends ActionResponse implements ToXContentObject { + private final String id; + + public JobResponse(StreamInput in) throws IOException { + super(in); + id = in.readString(); + } + + public JobResponse(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field(RestHandlerUtils._ID, id).endObject(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeRequest.java similarity index 74% rename from src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileNodeRequest.java index d3db87d33..a5ebfb61a 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeRequest.java @@ -9,14 +9,14 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Set; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.transport.TransportRequest; /** @@ -39,8 +39,8 @@ public ProfileNodeRequest(ProfileRequest request) { this.request = request; } - public String getId() { - return request.getId(); + public String getConfigId() { + return request.getConfigId(); } /** @@ -48,16 +48,17 @@ public String getId() { * * @return the set that contains the profile names marked for retrieval */ - public Set getProfilesToBeRetrieved() { + public Set getProfilesToBeRetrieved() { return request.getProfilesToBeRetrieved(); } /** * - * @return Whether this is about a multi-entity detector or not + * @return Whether the models are stored in priority cache. AD single stream models are stored in ModelManager. + * Other models are stored in priority cache. */ - public boolean isForMultiEntityDetector() { - return request.isForMultiEntityDetector(); + public boolean isModelInPriorityCache() { + return request.isModelInPriorityCache(); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeResponse.java similarity index 92% rename from src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileNodeResponse.java index 9517f6add..37be94232 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileNodeResponse.java @@ -9,21 +9,20 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; import java.util.Map; import org.opensearch.action.support.nodes.BaseNodeResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfile; /** * Profile response on a node @@ -137,12 +136,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endObject(); - builder.field(ADCommonName.SHINGLE_SIZE, shingleSize); - builder.field(ADCommonName.ACTIVE_ENTITIES, activeEntities); - builder.field(ADCommonName.TOTAL_UPDATES, totalUpdates); + builder.field(CommonName.SHINGLE_SIZE, shingleSize); + builder.field(CommonName.ACTIVE_ENTITIES, activeEntities); + builder.field(CommonName.TOTAL_UPDATES, totalUpdates); - builder.field(ADCommonName.MODEL_COUNT, modelCount); - builder.startArray(ADCommonName.MODELS); + builder.field(CommonName.MODEL_COUNT, modelCount); + builder.startArray(CommonName.MODELS); for (ModelProfile modelProfile : modelProfiles) { builder.startObject(); modelProfile.toXContent(builder, params); diff --git a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java b/src/main/java/org/opensearch/timeseries/transport/ProfileRequest.java similarity index 56% rename from src/main/java/org/opensearch/ad/transport/ProfileRequest.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileRequest.java index ea779e733..07cdd3d3c 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileRequest.java @@ -9,73 +9,68 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.HashSet; import java.util.Set; import org.opensearch.action.support.nodes.BaseNodesRequest; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.ProfileName; /** * implements a request to obtain profiles about an AD detector */ public class ProfileRequest extends BaseNodesRequest<ProfileRequest> { - private Set<DetectorProfileName> profilesToBeRetrieved; - private String detectorId; - private boolean forMultiEntityDetector; + private Set<ProfileName> profilesToBeRetrieved; + private String configId; + private boolean modelInPriorityCache; public ProfileRequest(StreamInput in) throws IOException { super(in); int size = in.readVInt(); - profilesToBeRetrieved = new HashSet<DetectorProfileName>(); + profilesToBeRetrieved = new HashSet<ProfileName>(); if (size != 0) { for (int i = 0; i < size; i++) { - profilesToBeRetrieved.add(in.readEnum(DetectorProfileName.class)); + profilesToBeRetrieved.add(in.readEnum(ProfileName.class)); } } - detectorId = in.readString(); - forMultiEntityDetector = in.readBoolean(); + configId = in.readString(); + modelInPriorityCache = in.readBoolean(); } /** * Constructor * - * @param detectorId detector's id + * @param configId config id * @param profilesToBeRetrieved profiles to be retrieved - * @param forMultiEntityDetector whether the request is for a multi-entity detector + * @param forHC whether the request is for a high-cardinality analysis * @param nodes nodes of nodes' profiles to be retrieved */ - public ProfileRequest( - String detectorId, - Set<DetectorProfileName> profilesToBeRetrieved, - boolean forMultiEntityDetector, - DiscoveryNode... nodes - ) { + public ProfileRequest(String configId, Set<ProfileName> profilesToBeRetrieved, boolean forHC, DiscoveryNode... nodes) { super(nodes); - this.detectorId = detectorId; + this.configId = configId; this.profilesToBeRetrieved = profilesToBeRetrieved; - this.forMultiEntityDetector = forMultiEntityDetector; + this.modelInPriorityCache = forHC; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeVInt(profilesToBeRetrieved.size()); - for (DetectorProfileName profile : profilesToBeRetrieved) { + for (ProfileName profile : profilesToBeRetrieved) { out.writeEnum(profile); } - out.writeString(detectorId); - out.writeBoolean(forMultiEntityDetector); + out.writeString(configId); + out.writeBoolean(modelInPriorityCache); } - public String getId() { - return detectorId; + public String getConfigId() { + return configId; } /** @@ -83,15 +78,16 @@ public String getId() { * * @return the set that contains the profile names marked for retrieval */ - public Set<DetectorProfileName> getProfilesToBeRetrieved() { + public Set<ProfileName> getProfilesToBeRetrieved() { return profilesToBeRetrieved; } /** * - * @return Whether this is about a multi-entity detector or not + * @return Whether the models are stored in priority cache. AD single stream models are stored in ModelManager. + * Other models are stored in priority cache.
*/ - public boolean isForMultiEntityDetector() { - return forMultiEntityDetector; + public boolean isModelInPriorityCache() { + return modelInPriorityCache; } } diff --git a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java b/src/main/java/org/opensearch/timeseries/transport/ProfileResponse.java similarity index 90% rename from src/main/java/org/opensearch/ad/transport/ProfileResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ProfileResponse.java index 11ba28163..9d1680430 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ProfileResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.ArrayList; @@ -20,14 +20,14 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; /** * This class consists of the aggregated responses from the nodes @@ -35,13 +35,13 @@ public class ProfileResponse extends BaseNodesResponse implements ToXContentFragment { private static final Logger LOG = LogManager.getLogger(ProfileResponse.class); // filed name in toXContent - static final String COORDINATING_NODE = ADCommonName.COORDINATING_NODE; - static final String SHINGLE_SIZE = ADCommonName.SHINGLE_SIZE; - static final String TOTAL_SIZE = ADCommonName.TOTAL_SIZE_IN_BYTES; - static final String ACTIVE_ENTITY = ADCommonName.ACTIVE_ENTITIES; - static final String MODELS = ADCommonName.MODELS; - static final String TOTAL_UPDATES = ADCommonName.TOTAL_UPDATES; - static final String MODEL_COUNT = ADCommonName.MODEL_COUNT; + public static final String COORDINATING_NODE = CommonName.COORDINATING_NODE; + public static final String SHINGLE_SIZE = CommonName.SHINGLE_SIZE; + public static final String TOTAL_SIZE = CommonName.TOTAL_SIZE_IN_BYTES; + static final String ACTIVE_ENTITY = CommonName.ACTIVE_ENTITIES; + public static final String MODELS = CommonName.MODELS; + static final String TOTAL_UPDATES = CommonName.TOTAL_UPDATES; + static final String MODEL_COUNT = CommonName.MODEL_COUNT; // changed from ModelProfile to ModelProfileOnNode since Opensearch 1.1 private ModelProfileOnNode[] modelProfile; diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java new file mode 100644 index 000000000..cd8efc9de --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ValidateActions; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ResultBulkRequest> extends + ActionRequest + implements + Writeable { + private final List results; + + public ResultBulkRequest() { + results = new ArrayList<>(); + } + + public ResultBulkRequest(StreamInput in, Writeable.Reader reader) throws IOException { + super(in); + int size = in.readVInt(); + results = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + results.add(reader.read(in)); + } + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (results.isEmpty()) { + validationException = ValidateActions.addValidationError(CommonMessages.NO_REQUESTS_ADDED_ERR, validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(results.size()); + for (ResultWriteRequestType result : results) { + result.writeTo(out); + } + } + + /** + * + * @return all of the results to send + */ + public List getAnomalyResults() { + return results; + } + + /** + * Add result to send + * @param resultWriteRequest The result write request + */ + public void add(ResultWriteRequestType resultWriteRequest) { + results.add(resultWriteRequest); + } + + /** + * + * @return total index requests + */ + public int numberOfActions() { + return results.size(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java index 70768311c..570bddca2 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.ArrayList; @@ -21,7 +21,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -public class ADResultBulkResponse extends ActionResponse { +public class ResultBulkResponse extends ActionResponse { public static final String RETRY_REQUESTS_JSON_KEY = "retry_requests"; private List retryRequests; @@ -30,15 +30,15 @@ public class ADResultBulkResponse extends ActionResponse { * * @param retryRequests a list of requests to retry */ - public ADResultBulkResponse(List retryRequests) { + public ResultBulkResponse(List retryRequests) { this.retryRequests = retryRequests; } - public ADResultBulkResponse() { + public ResultBulkResponse() { this.retryRequests = null; } - public ADResultBulkResponse(StreamInput in) throws IOException { + public ResultBulkResponse(StreamInput in) throws IOException { int size = in.readInt(); if (size > 0) { retryRequests = new ArrayList<>(size); diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java new file mode 100644 index 000000000..f070c38c6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; + +import java.io.IOException; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexingPressure; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.util.BulkUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +@SuppressWarnings("rawtypes") +public abstract class ResultBulkTransportAction, ResultBulkRequestType extends ResultBulkRequest> + extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(ResultBulkTransportAction.class); + protected IndexingPressure indexingPressure; + private final long primaryAndCoordinatingLimits; + protected float softLimit; + protected float hardLimit; + protected String indexName; + private Client client; + protected Random random; 
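+ // A minimal sketch of the throttling contract a subclass is expected to implement in + // prepareBulkRequest, per the pressure notes in doExecute below (illustrative only; + // every name other than the inherited fields and addResult is hypothetical): + // + // @Override + // protected BulkRequest prepareBulkRequest(float indexingPressurePercent, MyResultBulkRequest request) { + // BulkRequest bulkRequest = new BulkRequest(); + // for (MyResultWriteRequest r : request.getAnomalyResults()) { + // boolean keep = indexingPressurePercent <= softLimit // low pressure: keep everything + // || (indexingPressurePercent <= hardLimit && r.isHighPriority()) // moderate: keep important results + // || random.nextFloat() < 1 - indexingPressurePercent; // otherwise sample survivors + // if (keep) { + // addResult(bulkRequest, r.getResult(), r.getResultIndex()); + // } + // } + // return bulkRequest; + // }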
+ + public ResultBulkTransportAction( + String actionName, + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + Client client, + float softLimit, + float hardLimit, + String indexName, + Writeable.Reader<ResultBulkRequestType> requestReader + ) { + super(actionName, transportService, actionFilters, requestReader, ThreadPool.Names.SAME); + this.indexingPressure = indexingPressure; + this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); + this.client = client; + + this.softLimit = softLimit; + this.hardLimit = hardLimit; + this.indexName = indexName; + + // random seed is 42. Can be any number + this.random = new Random(42); + } + + @Override + protected void doExecute(Task task, ResultBulkRequestType request, ActionListener<ResultBulkResponse> listener) { + // Concurrent indexing memory limit = 10% of heap + // indexing pressure = indexing bytes / indexing limit + // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index + // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). + long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); + float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; + @SuppressWarnings("rawtypes") + List results = request.getAnomalyResults(); + + if (results == null || results.size() < 1) { + listener.onResponse(new ResultBulkResponse()); + // return early so we don't respond a second time below + return; + } + + BulkRequest bulkRequest = prepareBulkRequest(indexingPressurePercent, request); + + if (bulkRequest.numberOfActions() > 0) { + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { + List<IndexRequest> failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); + listener.onResponse(new ResultBulkResponse(failedRequests)); + }, e -> { + LOG.error("Failed to bulk index AD result", e); + listener.onFailure(e); + })); + } else { + listener.onResponse(new ResultBulkResponse()); + } + } + + protected abstract BulkRequest prepareBulkRequest(float indexingPressurePercent, ResultBulkRequestType request); + + protected void addResult(BulkRequest bulkRequest, ToXContentObject result, String resultIndex) { + String index = resultIndex == null ? indexName : resultIndex; + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(index).source(result.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (IOException e) { + LOG.error("Failed to prepare bulk index request for index " + index, e); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java new file mode 100644 index 000000000..2dcf6bde4 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java @@ -0,0 +1,813 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.transport; + +import java.net.ConnectException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.NetworkExceptionHelper; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.node.NodeClosedException; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.CompositeRetriever; +import org.opensearch.timeseries.feature.CompositeRetriever.PageIterator; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import 
org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.ActionNotFoundTransportException; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.NodeNotConnectedException; +import org.opensearch.transport.ReceiveTimeoutTransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportService; + +public abstract class ResultProcessor, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager> { + + private static final Logger LOG = LogManager.getLogger(ResultProcessor.class); + + static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; + + static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; + + public static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; + + public static final String NULL_RESPONSE = "Received null response from"; + + public static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; + + public static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; + + public static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. Mute node"; + + protected final TransportRequestOptions option; + private String entityResultAction; + protected Class transportResultResponseClazz; + private StatNames hcRequestCountStat; + private String threadPoolName; + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. + // to ensure we don't block next interval, it is better to set it less than 1.0. 
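+ // Worked example, assuming illustrative values: with a 10-minute interval and a ratio + // of 0.9, the HC fan-out gets 10 min * 0.9 = 9 minutes to page through entities, + // leaving one minute of slack before the next scheduled run.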
+ private final float intervalRatioForRequest; + private int maxEntitiesPerInterval; + private int pageSize; + protected final ThreadPool threadPool; + private final HashRing hashRing; + protected final NodeStateManager nodeStateManager; + protected final TransportService transportService; + private final Stats timeSeriesStats; + private final TaskManagerType realTimeTaskManager; + private NamedXContentRegistry xContentRegistry; + private final Client client; + private final SecurityClientUtil clientUtil; + private Settings settings; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final ClusterService clusterService; + protected final FeatureManager featureManager; + protected final AnalysisType analysisType; + + protected boolean runOnce; + + public ResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + String threadPoolName, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + Stats timeSeriesStats, + TaskManagerType realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager, + Setting maxEntitiesPerIntervalSetting, + Setting pageSizeSetting, + AnalysisType context, + boolean runOnce + ) { + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(requestTimeoutSetting.get(settings)) + .build(); + this.intervalRatioForRequest = intervalRatioForRequests; + + this.maxEntitiesPerInterval = maxEntitiesPerIntervalSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxEntitiesPerIntervalSetting, it -> maxEntitiesPerInterval = it); + + this.pageSize = pageSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(pageSizeSetting, it -> pageSize = it); + + this.entityResultAction = entityResultAction; + this.hcRequestCountStat = hcRequestCountStat; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.nodeStateManager = nodeStateManager; + this.transportService = transportService; + this.timeSeriesStats = timeSeriesStats; + this.realTimeTaskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.client = client; + this.clientUtil = clientUtil; + this.settings = settings; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.clusterService = clusterService; + this.transportResultResponseClazz = transportResultResponseClazz; + this.featureManager = featureManager; + this.analysisType = context; + this.threadPoolName = threadPoolName; + this.runOnce = runOnce; + } + + /** + * didn't use ActionListener.wrap so that I can + * 1) use this to refer to the listener inside the listener + * 2) pass parameters using constructors + * + */ + class PageListener implements ActionListener { + private PageIterator pageIterator; + private String configId; + private long dataStartTime; + private long dataEndTime; + private Runnable finishRunnable; + private String taskId; + + PageListener( + PageIterator pageIterator, + String detectorId, + long dataStartTime, + long dataEndTime, + Runnable finishRunnable, + String taskId + ) { + this.pageIterator = pageIterator; + this.configId = detectorId; + this.dataStartTime 
= dataStartTime; + this.dataEndTime = dataEndTime; + this.finishRunnable = finishRunnable; + this.taskId = taskId; + } + + @Override + public void onResponse(CompositeRetriever.Page entityFeatures) { + if (pageIterator.hasNext()) { + pageIterator.next(this); + } else { + finishRunnable.run(); + } + if (entityFeatures != null && false == entityFeatures.isEmpty()) { + // wrap this expensive operation inside the AD thread pool + threadPool.executor(threadPoolName).execute(() -> { + try { + + Set<Entry<DiscoveryNode, Map<Entity, double[]>>> node2Entities = entityFeatures + .getResults() + .entrySet() + .stream() + .filter(e -> hashRing.getOwningNodeWithSameLocalVersionForRealtime(e.getKey().toString()).isPresent()) + .collect( + Collectors + .groupingBy( + // from entity name to its node + e -> hashRing.getOwningNodeWithSameLocalVersionForRealtime(e.getKey().toString()).get(), + Collectors.toMap(Entry::getKey, Entry::getValue) + ) + ) + .entrySet(); + + Iterator<Entry<DiscoveryNode, Map<Entity, double[]>>> iterator = node2Entities.iterator(); + + while (iterator.hasNext()) { + Entry<DiscoveryNode, Map<Entity, double[]>> entry = iterator.next(); + DiscoveryNode modelNode = entry.getKey(); + if (modelNode == null) { + iterator.remove(); + continue; + } + String modelNodeId = modelNode.getId(); + if (nodeStateManager.isMuted(modelNodeId, configId)) { + LOG + .info( + String + .format( + Locale.ROOT, + ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", + modelNodeId, + configId + ) + ); + iterator.remove(); + } + } + + final AtomicReference<Exception> failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + entityResultAction, + new EntityResultRequest( + configId, + nodeEntity.getValue(), + dataStartTime, + dataEndTime, + analysisType, + taskId + ), + option, + new ActionListenerResponseHandler<>( + new ErrorResponseListener(node.getId(), configId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); + + } catch (Exception e) { + LOG.error("Unexpected exception", e); + handleException(e); + } + }); + } + } + + @Override + public void onFailure(Exception e) { + LOG.error("Unexpected exception", e); + handleException(e); + } + + private void handleException(Exception e) { + Exception convertedException = convertedQueryFailureException(e, configId); + if (false == (convertedException instanceof TimeSeriesException)) { + Throwable cause = ExceptionsHelper.unwrapCause(convertedException); + convertedException = new InternalFailure(configId, cause); + } + nodeStateManager.setException(configId, convertedException); + } + } + + public ActionListener<Optional<? extends Config>> onGetConfig( + ActionListener<TransportResultResponseType> listener, + String configID, + TransportResultRequestType request, + Optional<Set<String>> hcDetectors + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(configID, "config is not available.", true)); + return; + } + + Config config = configOptional.get(); + // no stat increment in runOnce where hcDetectors is empty.
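+ // Worked example for the window-delay arithmetic below (illustrative numbers): a + // request covering [10:00, 10:10) on a config with a 2-minute windowDelay yields + // delayMillis = 120000, so features are queried over [09:58, 10:08), giving late + // documents time to arrive before they are scored.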
+ if (config.isHighCardinality() && hcDetectors.isPresent()) { + hcDetectors.get().add(configID); + timeSeriesStats.getStat(hcRequestCountStat.getName()).increment(); + } + + if (request.getStart() <= 0) { + long duration = config.getIntervalInMilliseconds(); + long executionStartTime = request.getEnd() - duration; + + request.setStart(executionStartTime); + } + long delayMillis = Optional + .ofNullable((IntervalTimeConfiguration) config.getWindowDelay()) + .map(t -> t.toDuration().toMillis()) + .orElse(0L); + long dataStartTime = request.getStart() - delayMillis; + long dataEndTime = request.getEnd() - delayMillis; + + if (runOnce) { + realTimeTaskManager.createRunOnceTaskAndCleanupStaleTasks(configID, config, transportService, ActionListener.wrap(r -> { + if (r == null) { + LOG.error("Unexpected empty new task for " + configID); + listener + .onFailure( + new OpenSearchStatusException( + "Failed to bootstrap run once task for " + configID, + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + return; + } + executeAnalysis(listener, configID, request, config, dataStartTime, dataEndTime, r.getTaskId()); + }, e -> { + LOG.error("Failed to init run once task for " + configID, e); + listener + .onFailure( + new OpenSearchStatusException( + "Failed to bootstrap run once task for " + configID, + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + })); + } else { + realTimeTaskManager + .initRealtimeTaskCacheAndCleanupStaleCache( + configID, + config, + transportService, + ActionListener + .runAfter( + initRealtimeTaskListener(configID), + () -> executeAnalysis(listener, configID, request, config, dataStartTime, dataEndTime, null) + ) + ); + } + + }, exception -> ResultProcessor.handleExecuteException(exception, listener, configID)); + } + + private ActionListener<Boolean> initRealtimeTaskListener(String configId) { + return ActionListener.wrap(r -> { + if (r) { + LOG.debug("Realtime task initialized for config {}", configId); + } + }, e -> LOG.error("Failed to init realtime task for " + configId, e)); + } + + private void executeAnalysis( + ActionListener<TransportResultResponseType> listener, + String configID, + ResultRequest request, + Config config, + long dataStartTime, + long dataEndTime, + String taskId + ) { + // HC logic starts here + if (config.isHighCardinality()) { + Optional<Exception> previousException = nodeStateManager.fetchExceptionAndClear(configID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of [{}]", configID), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + // assume request timestamps are in epoch milliseconds + long nextDetectionStartTime = request.getEnd() + (long) (config.getIntervalInMilliseconds() * intervalRatioForRequest); + + CompositeRetriever compositeRetriever = new CompositeRetriever( + dataStartTime, + dataEndTime, + config, + xContentRegistry, + client, + clientUtil, + nextDetectionStartTime, + settings, + maxEntitiesPerInterval, + pageSize, + indexNameExpressionResolver, + clusterService, + analysisType + ); + + PageIterator pageIterator = null; + + try { + pageIterator = compositeRetriever.iterator(); + } catch (Exception e) { + listener.onFailure(new EndRunException(config.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); + return; + } + + Runnable finishRunnable = () -> { + // When pagination finishes or the time is up, + // return response or exceptions.
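+ // A stored exception from the previous run (fetched above via fetchExceptionAndClear) + // takes precedence over returning an empty success response.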
+ if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else { + listener + .onResponse( + createResultResponse(new ArrayList<FeatureData>(), null, null, config.getIntervalInMinutes(), true, taskId) + ); + } + }; + + PageListener getEntityFeaturesListener = new PageListener( + pageIterator, + configID, + dataStartTime, + dataEndTime, + finishRunnable, + taskId + ); + if (pageIterator.hasNext()) { + pageIterator.next(getEntityFeaturesListener); + } + + return; + } + + // HC logic ends and single entity logic starts here + // We are going to use only 1 model partition for a single stream detector. + // That's why we use 0 here. + String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(configID, 0); + Optional<DiscoveryNode> asRCFNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID); + if (!asRCFNode.isPresent()) { + listener.onFailure(new InternalFailure(configID, "RCF model node is not available.")); + return; + } + + DiscoveryNode rcfNode = asRCFNode.get(); + + if (!shouldStart(listener, configID, config, rcfNode.getId(), rcfModelID)) { + return; + } + + featureManager + .getCurrentFeatures( + config, + dataStartTime, + dataEndTime, + onFeatureResponseForSingleStreamConfig(configID, config, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime, taskId) + ); + } + + protected void handleQueryFailure(Exception exception, ActionListener<TransportResultResponseType> listener, String adID) { + Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); + + if (convertedQueryFailureException instanceof EndRunException) { + // invalid feature query + listener.onFailure(convertedQueryFailureException); + } else { + ResultProcessor.handleExecuteException(convertedQueryFailureException, listener, adID); + } + } + + /** + * Convert a query related exception to EndRunException + * + * These query exceptions can happen during the starting phase of the OpenSearch + * process. Thus, we set the stopNow parameter of these EndRunExceptions to false + * and confirm the EndRunException is not a false positive. + * + * @param exception Exception + * @param adID detector Id + * @return the converted exception if the exception is query related + */ + private Exception convertedQueryFailureException(Exception exception, String adID) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + return new EndRunException(adID, ResultProcessor.TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false) + .countedInStats(false); + } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { + // This is to catch invalid aggregation on wrong field type. For example, + // a sum aggregation on a text field. We should end the detector run in such a case. + return new EndRunException( + adID, + CommonMessages.INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), + exception, + false + ).countedInStats(false); + } + + return exception; + } + + protected void findException(Throwable cause, String configID, AtomicReference<Exception> failure, String nodeId) { + if (cause == null) { + LOG.error(new ParameterizedMessage("Null input exception")); + return; + } + if (cause instanceof Error) { + // we cannot do anything with Error.
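+ // (An Error, e.g. OutOfMemoryError, is not an Exception and cannot be stored in the + // AtomicReference<Exception> used for failures, so we only log it and return.)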
+ LOG.error(new ParameterizedMessage("Error during prediction for {}: ", configID), cause); + return; + } + + Exception causeException = (Exception) cause; + + if (causeException instanceof TimeSeriesException) { + failure.set(causeException); + } else if (causeException instanceof NotSerializableExceptionWrapper) { + // we only expect this happens on AD exceptions + Optional actualException = NotSerializedExceptionName + .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, configID); + if (actualException.isPresent()) { + TimeSeriesException adException = actualException.get(); + failure.set(adException); + if (adException instanceof ResourceNotFoundException) { + // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF + // 1.0 + // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node + // after consecutive failures. + nodeStateManager.addPressure(nodeId, configID); + } + } else { + // some unexpected bugs occur while predicting anomaly + failure.set(new EndRunException(configID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } else if (causeException instanceof OpenSearchTimeoutException) { + // we can have OpenSearchTimeoutException when a node tries to load RCF or + // threshold model + failure.set(new InternalFailure(configID, causeException)); + } else if (causeException instanceof IllegalArgumentException) { + // we can have IllegalArgumentException when a model is corrupted + failure.set(new InternalFailure(configID, causeException)); + } else { + // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. + // NoShardAvailableActionException) while predicting anomaly + failure.set(new EndRunException(configID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } + + private boolean invalidQuery(SearchPhaseExecutionException ex) { + // If all shards return bad request and failure cause is IllegalArgumentException, we + // consider the feature query is invalid and will not count the error in failure stats. + for (ShardSearchFailure failure : ex.shardFailures()) { + if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { + return false; + } + } + return true; + } + + /** + * Handle a prediction failure. Possibly (i.e., we don't always need to do that) + * convert the exception to a form that AD can recognize and handle and sets the + * input failure reference to the converted exception. + * + * @param e prediction exception + * @param adID Detector Id + * @param nodeID Node Id + * @param failure Parameter to receive the possibly converted function for the + * caller to deal with + */ + protected void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { + LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); + if (e == null) { + return; + } + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (hasConnectionIssue(cause)) { + handleConnectionException(nodeID, adID); + } else { + findException(cause, adID, failure, nodeID); + } + } + + /** + * Check if the input exception indicates connection issues. + * During blue-green deployment, we may see ActionNotFoundTransportException. + * Count that as connection issue and isolate that node if it continues to happen. 
+     *
+     * @param e exception
+     * @return true if we get disconnected from the node, the node is not in the
+     *  right state (being closed), or the transport request times out (sent from TimeoutHandler.run)
+     */
+    private boolean hasConnectionIssue(Throwable e) {
+        return e instanceof ConnectTransportException
+            || e instanceof NodeClosedException
+            || e instanceof ReceiveTimeoutTransportException
+            || e instanceof NodeNotConnectedException
+            || e instanceof ConnectException
+            || NetworkExceptionHelper.isCloseConnectionException(e)
+            || e instanceof ActionNotFoundTransportException;
+    }
+
+    private void handleConnectionException(String node, String detectorId) {
+        final DiscoveryNodes nodes = clusterService.state().nodes();
+        if (!nodes.nodeExists(node)) {
+            hashRing.buildCirclesForRealtime();
+            return;
+        }
+        // rebuilding is not done or the node is unresponsive
+        nodeStateManager.addPressure(node, detectorId);
+    }
+
+    /**
+     * Since we need to read from the user's data index and write to the anomaly result
+     * index, we need to make sure we can do both.
+     *
+     * @param state Cluster state
+     * @return whether a global block is set or not
+     */
+    private boolean checkGlobalBlock(ClusterState state) {
+        return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null
+            || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null;
+    }
+
+    /**
+     * Similar to checkGlobalBlock, we check blocks at the index level.
+     *
+     * @param state Cluster state
+     * @param level block level
+     * @param indices the indices on which to check blocks
+     * @return whether any of the indices has a block at the given level.
+     */
+    private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) {
+        // the original index might be an index expression with wildcards like "log*",
+        // so we need to expand the expression to concrete index names
+        String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices);
+
+        return state.blocks().indicesBlockedException(level, concreteIndices) != null;
+    }
+
+    /**
+     * Check if we should start anomaly prediction.
+     *
+     * @param listener listener to respond back to AnomalyResultRequest.
+     * @param adID detector ID
+     * @param detector detector instance corresponding to adID
+     * @param rcfNodeId the rcf model hosting node ID for adID
+     * @param rcfModelID the rcf model ID for adID
+     * @return whether we can start anomaly prediction.
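+     *
+     * The checks run in order: the cluster-wide block check, the node mute check, and
+     * finally the per-index block check, which expands wildcard index patterns.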
+     */
+    private boolean shouldStart(
+        ActionListener<ResultResponseType> listener,
+        String adID,
+        Config detector,
+        String rcfNodeId,
+        String rcfModelID
+    ) {
+        ClusterState state = clusterService.state();
+        if (checkGlobalBlock(state)) {
+            listener.onFailure(new InternalFailure(adID, ResultProcessor.READ_WRITE_BLOCKED));
+            return false;
+        }
+
+        if (nodeStateManager.isMuted(rcfNodeId, adID)) {
+            listener
+                .onFailure(
+                    new InternalFailure(
+                        adID,
+                        String
+                            .format(Locale.ROOT, ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID)
+                    )
+                );
+            return false;
+        }
+
+        if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) {
+            listener.onFailure(new InternalFailure(adID, ResultProcessor.INDEX_READ_BLOCKED));
+            return false;
+        }
+
+        return true;
+    }
+
+    public static void handleExecuteException(Exception ex, ActionListener listener, String id) {
+        if (ex instanceof ClientException) {
+            listener.onFailure(ex);
+        } else if (ex instanceof TimeSeriesException) {
+            listener.onFailure(new InternalFailure((TimeSeriesException) ex));
+        } else {
+            Throwable cause = ExceptionsHelper.unwrapCause(ex);
+            listener.onFailure(new InternalFailure(id, cause));
+        }
+    }
+
+    public class ErrorResponseListener implements ActionListener<AcknowledgedResponse> {
+        private String nodeId;
+        private final String configId;
+        private AtomicReference<Exception> failure;
+
+        public ErrorResponseListener(String nodeId, String configId, AtomicReference<Exception> failure) {
+            this.nodeId = nodeId;
+            this.configId = configId;
+            this.failure = failure;
+        }
+
+        @Override
+        public void onResponse(AcknowledgedResponse response) {
+            try {
+                if (response.isAcknowledged() == false) {
+                    LOG.error("Cannot send entities' features to {} for {}", nodeId, configId);
+                    nodeStateManager.addPressure(nodeId, configId);
+                } else {
+                    nodeStateManager.resetBackpressureCounter(nodeId, configId);
+                }
+            } catch (Exception ex) {
+                LOG.error(new ParameterizedMessage("Unexpected exception for {}", configId), ex);
+                handleException(ex);
+            }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            try {
+                // e.g., we have connection issues with all of the nodes while restarting clusters
+                LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, configId), e);
+
+                handleException(e);
+
+            } catch (Exception ex) {
+                LOG.error(new ParameterizedMessage("Unexpected exception for {}", configId), ex);
+                handleException(ex);
+            }
+        }
+
+        private void handleException(Exception e) {
+            handlePredictionFailure(e, configId, nodeId, failure);
+            if (failure.get() != null) {
+                nodeStateManager.setException(configId, failure.get());
+            }
+        }
+    }
+
+    protected abstract ResultResponseType createResultResponse(
+        List<FeatureData> features,
+        String error,
+        Long rcfTotalUpdates,
+        Long configInterval,
+        Boolean isHC,
+        String taskId
+    );
+
+    protected abstract ActionListener onFeatureResponseForSingleStreamConfig(
+        String configId,
+        Config config,
+        ActionListener listener,
+        String rcfModelId,
+        DiscoveryNode rcfNode,
+        long dataStartTime,
+        long dataEndTime,
+        String taskId
+    );
+}
diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java
new file mode 100644
index 000000000..c1e6a345f
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java
@@ -0,0 +1,64 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open
source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; + +public abstract class ResultRequest extends ActionRequest implements ToXContentObject { + protected String configId; + // time range start and end. Unit: epoch milliseconds + protected long start; + protected long end; + + public ResultRequest(StreamInput in) throws IOException { + super(in); + configId = in.readString(); + start = in.readLong(); + end = in.readLong(); + } + + public ResultRequest(String configID, long start, long end) { + super(); + this.configId = configID; + this.start = start; + this.end = end; + } + + public long getStart() { + return start; + } + + public void setStart(long start) { + this.start = start; + } + + public long getEnd() { + return end; + } + + public String getConfigId() { + return configId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configId); + out.writeLong(start); + out.writeLong(end); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java b/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java new file mode 100644 index 000000000..38e566f3d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java @@ -0,0 +1,101 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.timeseries.transport;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+
+import org.opensearch.commons.authuser.User;
+import org.opensearch.core.action.ActionResponse;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.timeseries.model.FeatureData;
+import org.opensearch.timeseries.model.IndexableResult;
+
+public abstract class ResultResponse extends ActionResponse implements ToXContentObject {
+
+    protected String error;
+    protected List<FeatureData> features;
+    protected Long rcfTotalUpdates;
+    protected Long configIntervalInMinutes;
+    protected Boolean isHC;
+    protected String taskId;
+
+    public ResultResponse(
+        List<FeatureData> features,
+        String error,
+        Long rcfTotalUpdates,
+        Long configInterval,
+        Boolean isHC,
+        String taskId
+    ) {
+        this.error = error;
+        this.features = features;
+        this.rcfTotalUpdates = rcfTotalUpdates;
+        this.configIntervalInMinutes = configInterval;
+        this.isHC = isHC;
+        this.taskId = taskId;
+    }
+
+    /**
+     * Leave how to deserialize the concrete result response as an implementation
+     * detail of the subclass.
+     * @param in deserialization stream
+     * @throws IOException when deserialization fails
+     */
+    public ResultResponse(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    public String getError() {
+        return error;
+    }
+
+    public List<FeatureData> getFeatures() {
+        return features;
+    }
+
+    public Long getRcfTotalUpdates() {
+        return rcfTotalUpdates;
+    }
+
+    public Long getConfigIntervalInMinutes() {
+        return configIntervalInMinutes;
+    }
+
+    public Boolean isHC() {
+        return isHC;
+    }
+
+    public String getTaskId() {
+        return taskId;
+    }
+
+    /**
+     *
+     * @return whether we should save the response to the result index
+     */
+    public boolean shouldSave() {
+        return error != null;
+    }
+
+    public abstract List<IndexableResult> toIndexableResults(
+        String configId,
+        Instant dataStartInstant,
+        Instant dataEndInstant,
+        Instant executionStartInstant,
+        Instant executionEndInstant,
+        Integer schemaVersion,
+        User user,
+        String error
+    );
+}
diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoRequest.java
similarity index 80%
rename from src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java
rename to src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoRequest.java
index 8289619c1..1592dd594 100644
--- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java
+++ b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoRequest.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -18,18 +18,18 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -public class SearchAnomalyDetectorInfoRequest extends ActionRequest { +public class SearchConfigInfoRequest extends ActionRequest { private String name; private String rawPath; - public SearchAnomalyDetectorInfoRequest(StreamInput in) throws IOException { + public SearchConfigInfoRequest(StreamInput in) throws IOException { super(in); name = in.readOptionalString(); rawPath = in.readString(); } - public SearchAnomalyDetectorInfoRequest(String name, String rawPath) throws IOException { + public SearchConfigInfoRequest(String name, String rawPath) throws IOException { super(); this.name = name; this.rawPath = rawPath; diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoResponse.java similarity index 83% rename from src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java rename to src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoResponse.java index 852c39d1a..67b44953e 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/SearchConfigInfoResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -20,17 +20,17 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.util.RestHandlerUtils; -public class SearchAnomalyDetectorInfoResponse extends ActionResponse implements ToXContentObject { +public class SearchConfigInfoResponse extends ActionResponse implements ToXContentObject { private long count; private boolean nameExists; - public SearchAnomalyDetectorInfoResponse(StreamInput in) throws IOException { + public SearchConfigInfoResponse(StreamInput in) throws IOException { super(in); count = in.readLong(); nameExists = in.readBoolean(); } - public SearchAnomalyDetectorInfoResponse(long count, boolean nameExists) { + public SearchConfigInfoResponse(long count, boolean nameExists) { this.count = count; this.nameExists = nameExists; } diff --git a/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java new file mode 100644 index 000000000..4028e1565 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; + +public class SingleStreamResultRequest extends ActionRequest implements ToXContentObject { + private final String configId; + private final String modelId; + + // data start/end time epoch in milliseconds + private final long startMillis; + private final long endMillis; + private final double[] datapoint; + private final String taskId; + + public SingleStreamResultRequest(String configId, String modelId, long start, long end, double[] datapoint, String taskId) { + super(); + this.configId = configId; + this.modelId = modelId; + this.startMillis = start; + this.endMillis = end; + this.datapoint = datapoint; + this.taskId = taskId; + } + + public SingleStreamResultRequest(StreamInput in) throws IOException { + super(in); + this.configId = in.readString(); + this.modelId = in.readString(); + this.startMillis = in.readLong(); + this.endMillis = in.readLong(); + this.datapoint = in.readDoubleArray(); + this.taskId = in.readOptionalString(); + } + + public String getConfigId() { + return this.configId; + } + + public String getModelId() { + return modelId; + } + + public long getStart() { + return this.startMillis; + } + + public long getEnd() { + return this.endMillis; + } + + public double[] getDataPoint() { + return this.datapoint; + } + + public String getTaskId() { + return taskId; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.configId); + out.writeString(this.modelId); + out.writeLong(this.startMillis); + out.writeLong(this.endMillis); + out.writeDoubleArray(datapoint); + out.writeOptionalString(this.taskId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonName.CONFIG_ID_KEY, configId); + builder.field(CommonName.MODEL_ID_KEY, modelId); + builder.field(CommonName.START_JSON_KEY, startMillis); + builder.field(CommonName.END_JSON_KEY, endMillis); + builder.array(CommonName.VALUE_LIST_FIELD, datapoint); + builder.field(CommonName.RUN_ONCE_FIELD, taskId); + builder.endObject(); + return builder; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(configId)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); + } + if (Strings.isEmpty(modelId)) { + validationException = addValidationError(CommonMessages.MODEL_ID_MISSING_MSG, validationException); + } + if (startMillis <= 0 || endMillis <= 0 || startMillis > endMillis) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", CommonMessages.INVALID_TIMESTAMP_ERR_MSG, startMillis, endMillis), + validationException + ); + } + return validationException; + } +} diff --git 
a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/StatsNodeRequest.java similarity index 69% rename from src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/StatsNodeRequest.java index 099bc7db1..a61135d1a 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -18,21 +18,21 @@ import org.opensearch.transport.TransportRequest; /** - * ADStatsNodeRequest to get a nodes stat + * StatsNodeRequest to get a nodes stat */ -public class ADStatsNodeRequest extends TransportRequest { - private ADStatsRequest request; +public class StatsNodeRequest extends TransportRequest { + private StatsRequest request; /** * Constructor */ - public ADStatsNodeRequest() { + public StatsNodeRequest() { super(); } - public ADStatsNodeRequest(StreamInput in) throws IOException { + public StatsNodeRequest(StreamInput in) throws IOException { super(in); - this.request = new ADStatsRequest(in); + this.request = new StatsRequest(in); } /** @@ -40,7 +40,7 @@ public ADStatsNodeRequest(StreamInput in) throws IOException { * * @param request ADStatsRequest */ - public ADStatsNodeRequest(ADStatsRequest request) { + public StatsNodeRequest(StatsRequest request) { this.request = request; } @@ -49,7 +49,7 @@ public ADStatsNodeRequest(ADStatsRequest request) { * * @return ADStatsRequest for this node */ - public ADStatsRequest getADStatsRequest() { + public StatsRequest getADStatsRequest() { return request; } diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsNodeResponse.java similarity index 85% rename from src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsNodeResponse.java index f5296cf17..a1b5b180f 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Map; @@ -24,7 +24,7 @@ /** * ADStatsNodeResponse */ -public class ADStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { +public class StatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { private Map statsMap; @@ -34,7 +34,7 @@ public class ADStatsNodeResponse extends BaseNodeResponse implements ToXContentF * @param in StreamInput * @throws IOException throws an IO exception if the StreamInput cannot be read from */ - public ADStatsNodeResponse(StreamInput in) throws IOException { + public StatsNodeResponse(StreamInput in) throws IOException { super(in); this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); } @@ -45,7 +45,7 @@ public ADStatsNodeResponse(StreamInput in) throws IOException { * @param node node * @param statsToValues Mapping of stat name to value */ - public ADStatsNodeResponse(DiscoveryNode node, Map statsToValues) { + public StatsNodeResponse(DiscoveryNode node, Map statsToValues) { super(node); this.statsMap = statsToValues; } @@ -57,9 +57,9 @@ public ADStatsNodeResponse(DiscoveryNode node, Map statsToValues * @return ADStatsNodeResponse object corresponding to the input stream * @throws IOException throws an IO exception if the StreamInput cannot be read from */ - public static ADStatsNodeResponse readStats(StreamInput in) throws IOException { + public static StatsNodeResponse readStats(StreamInput in) throws IOException { - return new ADStatsNodeResponse(in); + return new StatsNodeResponse(in); } /** diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsNodesResponse.java similarity index 66% rename from src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsNodesResponse.java index 2dbdff03c..7a8ff9901 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsNodesResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; @@ -24,9 +24,9 @@ import org.opensearch.core.xcontent.XContentBuilder; /** - * ADStatsNodesResponse consists of the aggregated responses from the nodes + * StatsNodesResponse consists of the aggregated responses from the nodes */ -public class ADStatsNodesResponse extends BaseNodesResponse implements ToXContentObject { +public class StatsNodesResponse extends BaseNodesResponse implements ToXContentObject { private static final String NODES_KEY = "nodes"; @@ -36,18 +36,18 @@ public class ADStatsNodesResponse extends BaseNodesResponse * @param in StreamInput * @throws IOException thrown when unable to read from stream */ - public ADStatsNodesResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(ADStatsNodeResponse::readStats), in.readList(FailedNodeException::new)); + public StatsNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(StatsNodeResponse::readStats), in.readList(FailedNodeException::new)); } /** * Constructor * * @param clusterName name of cluster - * @param nodes List of ADStatsNodeResponses from nodes + * @param nodes List of StatsNodeResponse from nodes * @param failures List of failures from nodes */ - public ADStatsNodesResponse(ClusterName clusterName, List nodes, List failures) { + public StatsNodesResponse(ClusterName clusterName, List nodes, List failures) { super(clusterName, nodes, failures); } @@ -57,13 +57,13 @@ public void writeTo(StreamOutput out) throws IOException { } @Override - public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { out.writeList(nodes); } @Override - public List readNodesFrom(StreamInput in) throws IOException { - return in.readList(ADStatsNodeResponse::readStats); + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(StatsNodeResponse::readStats); } @Override @@ -71,7 +71,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws String nodeId; DiscoveryNode node; builder.startObject(NODES_KEY); - for (ADStatsNodeResponse adStats : getNodes()) { + for (StatsNodeResponse adStats : getNodes()) { node = adStats.getNode(); nodeId = node.getId(); builder.startObject(nodeId); diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java b/src/main/java/org/opensearch/timeseries/transport/StatsRequest.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/ADStatsRequest.java rename to src/main/java/org/opensearch/timeseries/transport/StatsRequest.java index 32301e526..f8b5a8896 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.HashSet; @@ -21,9 +21,9 @@ import org.opensearch.core.common.io.stream.StreamOutput; /** - * ADStatsRequest implements a request to obtain stats about the AD plugin + * StatsRequest implements a request to obtain stats about the time series analytics plugin */ -public class ADStatsRequest extends BaseNodesRequest { +public class StatsRequest extends BaseNodesRequest { /** * Key indicating all stats should be retrieved @@ -32,7 +32,7 @@ public class ADStatsRequest extends BaseNodesRequest { private Set statsToBeRetrieved; - public ADStatsRequest(StreamInput in) throws IOException { + public StatsRequest(StreamInput in) throws IOException { super(in); statsToBeRetrieved = in.readSet(StreamInput::readString); } @@ -42,7 +42,7 @@ public ADStatsRequest(StreamInput in) throws IOException { * * @param nodeIds nodeIds of nodes' stats to be retrieved */ - public ADStatsRequest(String... nodeIds) { + public StatsRequest(String... nodeIds) { super(nodeIds); statsToBeRetrieved = new HashSet<>(); } @@ -52,7 +52,7 @@ public ADStatsRequest(String... nodeIds) { * * @param nodes nodes of nodes' stats to be retrieved */ - public ADStatsRequest(DiscoveryNode... nodes) { + public StatsRequest(DiscoveryNode... nodes) { super(nodes); statsToBeRetrieved = new HashSet<>(); } diff --git a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsResponse.java similarity index 60% rename from src/main/java/org/opensearch/ad/stats/ADStatsResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsResponse.java index f90e451f9..414951e19 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.Map; @@ -17,19 +17,18 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.ad.model.Mergeable; -import org.opensearch.ad.transport.ADStatsNodesResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.Mergeable; /** - * ADStatsResponse contains logic to merge the node stats and cluster stats together and return them to user + * StatsResponse contains logic to merge the node stats and cluster stats together and return them to user */ -public class ADStatsResponse implements ToXContentObject, Mergeable { - private ADStatsNodesResponse adStatsNodesResponse; +public class StatsResponse implements ToXContentObject, Mergeable { + private StatsNodesResponse statsNodesResponse; private Map clusterStats; /** @@ -53,23 +52,23 @@ public void setClusterStats(Map clusterStats) { /** * Get cluster stats * - * @return ADStatsNodesResponse + * @return StatsNodesResponse */ - public ADStatsNodesResponse getADStatsNodesResponse() { - return adStatsNodesResponse; + public StatsNodesResponse getStatsNodesResponse() { + return statsNodesResponse; } /** - * Sets adStatsNodesResponse + * Sets statsNodesResponse * - * @param adStatsNodesResponse AD Stats Response from Nodes + * @param statsNodesResponse Stats Response from Nodes */ - public void setADStatsNodesResponse(ADStatsNodesResponse adStatsNodesResponse) { - this.adStatsNodesResponse = adStatsNodesResponse; + public void setStatsNodesResponse(StatsNodesResponse statsNodesResponse) { + this.statsNodesResponse = statsNodesResponse; } /** - * Convert ADStatsResponse to XContent + * Convert StatsResponse to XContent * * @param builder XContentBuilder * @return XContentBuilder @@ -79,15 +78,15 @@ public XContentBuilder toXContent(XContentBuilder builder) throws IOException { return toXContent(builder, ToXContent.EMPTY_PARAMS); } - public ADStatsResponse() {} + public StatsResponse() {} - public ADStatsResponse(StreamInput in) throws IOException { - adStatsNodesResponse = new ADStatsNodesResponse(in); + public StatsResponse(StreamInput in) throws IOException { + statsNodesResponse = new StatsNodesResponse(in); clusterStats = in.readMap(); } public void writeTo(StreamOutput out) throws IOException { - adStatsNodesResponse.writeTo(out); + statsNodesResponse.writeTo(out); out.writeMap(clusterStats); } @@ -97,7 +96,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws for (Map.Entry clusterStat : clusterStats.entrySet()) { builder.field(clusterStat.getKey(), clusterStat.getValue()); } - adStatsNodesResponse.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + statsNodesResponse.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); return xContentBuilder.endObject(); } @@ -107,10 +106,10 @@ public void merge(Mergeable other) { return; } - ADStatsResponse otherResponse = (ADStatsResponse) other; + StatsResponse otherResponse = (StatsResponse) other; - if (otherResponse.adStatsNodesResponse != null) { - this.adStatsNodesResponse = otherResponse.adStatsNodesResponse; + if (otherResponse.statsNodesResponse != null) { 
+ this.statsNodesResponse = otherResponse.statsNodesResponse; } if (otherResponse.clusterStats != null) { @@ -127,23 +126,17 @@ public boolean equals(Object obj) { if (getClass() != obj.getClass()) return false; - ADStatsResponse other = (ADStatsResponse) obj; - return new EqualsBuilder() - .append(adStatsNodesResponse, other.adStatsNodesResponse) - .append(clusterStats, other.clusterStats) - .isEquals(); + StatsResponse other = (StatsResponse) obj; + return new EqualsBuilder().append(statsNodesResponse, other.statsNodesResponse).append(clusterStats, other.clusterStats).isEquals(); } @Override public int hashCode() { - return new HashCodeBuilder().append(adStatsNodesResponse).append(clusterStats).toHashCode(); + return new HashCodeBuilder().append(statsNodesResponse).append(clusterStats).toHashCode(); } @Override public String toString() { - return new ToStringBuilder(this) - .append("adStatsNodesResponse", adStatsNodesResponse) - .append("clusterStats", clusterStats) - .toString(); + return new ToStringBuilder(this).append("statsNodesResponse", statsNodesResponse).append("clusterStats", clusterStats).toString(); } } diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/StatsTimeSeriesResponse.java similarity index 57% rename from src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StatsTimeSeriesResponse.java index c3a108454..ebde3a5f0 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StatsTimeSeriesResponse.java @@ -9,41 +9,40 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -public class StatsAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { - private ADStatsResponse adStatsResponse; +public class StatsTimeSeriesResponse extends ActionResponse implements ToXContentObject { + private StatsResponse statsResponse; - public StatsAnomalyDetectorResponse(StreamInput in) throws IOException { + public StatsTimeSeriesResponse(StreamInput in) throws IOException { super(in); - adStatsResponse = new ADStatsResponse(in); + statsResponse = new StatsResponse(in); } - public StatsAnomalyDetectorResponse(ADStatsResponse adStatsResponse) { - this.adStatsResponse = adStatsResponse; + public StatsTimeSeriesResponse(StatsResponse adStatsResponse) { + this.statsResponse = adStatsResponse; } @Override public void writeTo(StreamOutput out) throws IOException { - adStatsResponse.writeTo(out); + statsResponse.writeTo(out); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - adStatsResponse.toXContent(builder, params); + statsResponse.toXContent(builder, params); return builder; } - protected ADStatsResponse getAdStatsResponse() { - return adStatsResponse; + public StatsResponse getAdStatsResponse() { + return statsResponse; } } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java 
b/src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java similarity index 64% rename from src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java index 71563a2cd..da70786a3 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -19,8 +19,6 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -28,43 +26,45 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; -public class StopDetectorRequest extends ActionRequest implements ToXContentObject { +public class StopConfigRequest extends ActionRequest implements ToXContentObject { - private String adID; + private String configID; - public StopDetectorRequest() {} + public StopConfigRequest() {} - public StopDetectorRequest(StreamInput in) throws IOException { + public StopConfigRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } - public StopDetectorRequest(String adID) { + public StopConfigRequest(String configID) { super(); - this.adID = adID; + this.configID = configID; } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } - public StopDetectorRequest adID(String adID) { - this.adID = adID; + public StopConfigRequest adID(String configID) { + this.configID = configID; return this; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } @@ -72,20 +72,20 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.endObject(); return builder; } - public static StopDetectorRequest fromActionRequest(final ActionRequest actionRequest) { - if (actionRequest instanceof StopDetectorRequest) { - return (StopDetectorRequest) actionRequest; + public static StopConfigRequest fromActionRequest(final ActionRequest actionRequest) { + if (actionRequest instanceof StopConfigRequest) { + return (StopConfigRequest) actionRequest; } try 
(ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new StopDetectorRequest(input); + return new StopConfigRequest(input); } } catch (IOException e) { throw new IllegalArgumentException("failed to parse ActionRequest into StopDetectorRequest", e); diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java similarity index 78% rename from src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java index 00ca68649..d5ab03781 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -23,15 +23,15 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -public class StopDetectorResponse extends ActionResponse implements ToXContentObject { +public class StopConfigResponse extends ActionResponse implements ToXContentObject { public static final String SUCCESS_JSON_KEY = "success"; private boolean success; - public StopDetectorResponse(boolean success) { + public StopConfigResponse(boolean success) { this.success = success; } - public StopDetectorResponse(StreamInput in) throws IOException { + public StopConfigResponse(StreamInput in) throws IOException { super(in); success = in.readBoolean(); } @@ -53,15 +53,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static StopDetectorResponse fromActionResponse(final ActionResponse actionResponse) { - if (actionResponse instanceof StopDetectorResponse) { - return (StopDetectorResponse) actionResponse; + public static StopConfigResponse fromActionResponse(final ActionResponse actionResponse) { + if (actionResponse instanceof StopConfigResponse) { + return (StopConfigResponse) actionResponse; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new StopDetectorResponse(input); + return new StopConfigResponse(input); } } catch (IOException e) { throw new IllegalArgumentException("failed to parse ActionResponse into StopDetectorResponse", e); diff --git a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java new file mode 100644 index 000000000..3c7b9f45a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamRequest.java @@ -0,0 +1,80 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; + +public class SuggestConfigParamRequest extends ActionRequest { + + private final AnalysisType context; + private final Config config; + private final String param; + private final TimeValue requestTimeout; + + public SuggestConfigParamRequest(StreamInput in) throws IOException { + super(in); + context = in.readEnum(AnalysisType.class); + if (context.isAD()) { + config = new AnomalyDetector(in); + } else if (context.isForecast()) { + config = new Forecaster(in); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + param = in.readString(); + requestTimeout = in.readTimeValue(); + } + + public SuggestConfigParamRequest(AnalysisType context, Config config, String param, TimeValue requestTimeout) { + this.context = context; + this.config = config; + this.param = param; + this.requestTimeout = requestTimeout; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(context); + config.writeTo(out); + out.writeString(param); + out.writeTimeValue(requestTimeout); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public Config getConfig() { + return config; + } + + public String getParam() { + return param; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamResponse.java b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamResponse.java new file mode 100644 index 000000000..37238aa07 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/SuggestConfigParamResponse.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +public class SuggestConfigParamResponse extends ActionResponse implements ToXContentObject { + public static final String INTERVAL_FIELD = "interval"; + + private final IntervalTimeConfiguration interval; + + public IntervalTimeConfiguration getInterval() { + return interval; + } + + public SuggestConfigParamResponse(IntervalTimeConfiguration interval) { + this.interval = interval; + } + + public SuggestConfigParamResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + this.interval = IntervalTimeConfiguration.readFrom(in); + } else { + this.interval = null; + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (interval != null) { + out.writeBoolean(true); + interval.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(INTERVAL_FIELD, interval); + + return xContentBuilder.endObject(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java new file mode 100644 index 000000000..0ad5e86e4 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigRequest.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; + +public class ValidateConfigRequest extends ActionRequest { + + private final AnalysisType context; + private final Config config; + private final String validationType; + private final Integer maxSingleStreamConfigs; + private final Integer maxHCConfigs; + private final Integer maxFeatures; + private final TimeValue requestTimeout; + // added during refactoring for forecasting. It is fine we add a new field + // since the request is handled by the same node. 
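+    // Both serialization and deserialization of this request run on the same node and hence the
+    // same code version, so a new field cannot cause a wire-format mismatch across a mixed-version cluster.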
+ private final Integer maxCategoricalFields; + + public ValidateConfigRequest(StreamInput in) throws IOException { + super(in); + context = in.readEnum(AnalysisType.class); + if (context.isAD()) { + config = new AnomalyDetector(in); + } else if (context.isForecast()) { + config = new Forecaster(in); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + validationType = in.readString(); + maxSingleStreamConfigs = in.readInt(); + maxHCConfigs = in.readInt(); + maxFeatures = in.readInt(); + requestTimeout = in.readTimeValue(); + maxCategoricalFields = in.readInt(); + } + + public ValidateConfigRequest( + AnalysisType context, + Config config, + String validationType, + Integer maxSingleStreamConfigs, + Integer maxHCConfigs, + Integer maxFeatures, + TimeValue requestTimeout, + Integer maxCategoricalFields + ) { + this.context = context; + this.config = config; + this.validationType = validationType; + this.maxSingleStreamConfigs = maxSingleStreamConfigs; + this.maxHCConfigs = maxHCConfigs; + this.maxFeatures = maxFeatures; + this.requestTimeout = requestTimeout; + this.maxCategoricalFields = maxCategoricalFields; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(context); + config.writeTo(out); + out.writeString(validationType); + out.writeInt(maxSingleStreamConfigs); + out.writeInt(maxHCConfigs); + out.writeInt(maxFeatures); + out.writeTimeValue(requestTimeout); + out.writeInt(maxCategoricalFields); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public Config getConfig() { + return config; + } + + public String getValidationType() { + return validationType; + } + + public Integer getMaxSingleEntityAnomalyDetectors() { + return maxSingleStreamConfigs; + } + + public Integer getMaxMultiEntityAnomalyDetectors() { + return maxHCConfigs; + } + + public Integer getMaxAnomalyFeatures() { + return maxFeatures; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigResponse.java similarity index 75% rename from src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ValidateConfigResponse.java index d89022241..f3321024e 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ValidateConfigResponse.java @@ -9,33 +9,33 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.ConfigValidationIssue; -public class ValidateAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { - private DetectorValidationIssue issue; +public class ValidateConfigResponse extends ActionResponse implements ToXContentObject { + private ConfigValidationIssue issue; - public DetectorValidationIssue getIssue() { + public ConfigValidationIssue getIssue() { return issue; } - public ValidateAnomalyDetectorResponse(DetectorValidationIssue issue) { + public ValidateConfigResponse(ConfigValidationIssue issue) { this.issue = issue; } - public ValidateAnomalyDetectorResponse(StreamInput in) throws IOException { + public ValidateConfigResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - issue = new DetectorValidationIssue(in); + issue = new ConfigValidationIssue(in); } } diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..628d95d1c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; + +/** + * Different from ResultIndexingHandler and ResultBulkIndexingHandler, this class uses + * customized transport action to bulk index results. These transport action will + * reduce traffic when index memory pressure is high. 
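+ * For example, under heavy result-index pressure such an action can drop or de-prioritize
+ * low-priority results instead of forwarding every record.
+ *
+ * A hypothetical subclass wiring (the names in this sketch are illustrative, not part of
+ * this change) could look like:
+ * <pre>{@code
+ * class ADResultBulkHandler extends IndexMemoryPressureAwareResultHandler<
+ *     ADResultBulkRequest, ResultBulkResponse, ADIndex, ADIndexManagement> {
+ *     ADResultBulkHandler(Client client, ADIndexManagement indices) {
+ *         super(client, indices);
+ *     }
+ *
+ *     @Override
+ *     public void bulk(ADResultBulkRequest request, ActionListener<ResultBulkResponse> listener) {
+ *         // delegate to a transport action that can shed load under index pressure
+ *         client.execute(ADResultBulkAction.INSTANCE, request, listener);
+ *     }
+ * }
+ * }</pre>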
+ *
+ * @param <BatchRequestType> Batch request type
+ * @param <BatchResponseType> Batch response type
+ * @param <IndexType> forecasting or AD result index
+ * @param <IndexManagementType> Index management class
+ */
+public abstract class IndexMemoryPressureAwareResultHandler<BatchRequestType, BatchResponseType, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>> {
+    private static final Logger LOG = LogManager.getLogger(IndexMemoryPressureAwareResultHandler.class);
+
+    protected final Client client;
+    protected final IndexManagementType timeSeriesIndices;
+
+    public IndexMemoryPressureAwareResultHandler(Client client, IndexManagementType timeSeriesIndices) {
+        this.client = client;
+        this.timeSeriesIndices = timeSeriesIndices;
+    }
+
+    /**
+     * Execute the bulk request
+     * @param currentBulkRequest The bulk request
+     * @param listener callback after flushing
+     */
+    public void flush(BatchRequestType currentBulkRequest, ActionListener<BatchResponseType> listener) {
+        try {
+            // Only create the custom result index when creating a detector; we won't recreate a deleted custom AD result
+            // index during later realtime jobs or historical analyses. If the user deletes the custom AD result index and
+            // the AD plugin recreates it, that may bring confusion.
+            if (!timeSeriesIndices.doesDefaultResultIndexExist()) {
+                timeSeriesIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> {
+                    if (initResponse.isAcknowledged()) {
+                        bulk(currentBulkRequest, listener);
+                    } else {
+                        LOG.warn("Creating result index with mappings call not acknowledged.");
+                        listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged."));
+                    }
+                }, exception -> {
+                    if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
+                        // It is possible the index was created while we were sending the create request
+                        bulk(currentBulkRequest, listener);
+                    } else {
+                        LOG.warn("Unexpected error creating result index", exception);
+                        listener.onFailure(exception);
+                    }
+                }));
+            } else {
+                bulk(currentBulkRequest, listener);
+            }
+        } catch (Exception e) {
+            LOG.warn("Error bulk indexing results", e);
+            listener.onFailure(e);
+        }
+    }
+
+    public abstract void bulk(BatchRequestType currentBulkRequest, ActionListener<BatchResponseType> listener);
+}
diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java
similarity index 52%
rename from src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java
rename to src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java
index d61fd1794..4cec1c127 100644
--- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java
+++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java
@@ -9,11 +9,9 @@
 * GitHub history for details.
 */

-package org.opensearch.ad.transport.handler;
+package org.opensearch.timeseries.transport.handler;

-import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS;
 import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
-import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX;

 import java.util.List;
@@ -24,107 +22,130 @@
 import org.opensearch.action.bulk.BulkRequestBuilder;
 import org.opensearch.action.bulk.BulkResponse;
 import org.opensearch.action.index.IndexRequest;
-import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.model.AnomalyResult;
-import org.opensearch.ad.util.IndexUtils;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Setting;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.indices.IndexManagement;
+import org.opensearch.timeseries.indices.TimeSeriesIndex;
+import org.opensearch.timeseries.model.IndexableResult;
 import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.util.IndexUtils;
 import org.opensearch.timeseries.util.RestHandlerUtils;

-public class AnomalyResultBulkIndexHandler extends AnomalyIndexHandler<AnomalyResult> {
-    private static final Logger LOG = LogManager.getLogger(AnomalyResultBulkIndexHandler.class);
+/**
+ *
+ * Utility method to bulk index results
+ *
+ */
+public class ResultBulkIndexingHandler<ResultType extends IndexableResult, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>>
+    extends ResultIndexingHandler<ResultType, IndexType, IndexManagementType> {

-    private ADIndexManagement anomalyDetectionIndices;
+    private static final Logger LOG = LogManager.getLogger(ResultBulkIndexingHandler.class);

-    public AnomalyResultBulkIndexHandler(
+    public ResultBulkIndexingHandler(
         Client client,
         Settings settings,
         ThreadPool threadPool,
+        String indexName,
+        IndexManagementType timeSeriesIndices,
         ClientUtil clientUtil,
         IndexUtils indexUtils,
         ClusterService clusterService,
-        ADIndexManagement anomalyDetectionIndices
+        Setting<TimeValue> backOffDelaySetting,
+        Setting<Integer> maxRetrySetting
     ) {
-        super(client, settings, threadPool, ANOMALY_RESULT_INDEX_ALIAS, anomalyDetectionIndices, clientUtil, indexUtils, clusterService);
-        this.anomalyDetectionIndices = anomalyDetectionIndices;
+        super(
+            client,
+            settings,
+            threadPool,
+            indexName,
+            timeSeriesIndices,
+            clientUtil,
+            indexUtils,
+            clusterService,
+            backOffDelaySetting,
+            maxRetrySetting
+        );
     }

     /**
-     * Bulk index anomaly results. Create anomaly result index first if it doesn't exist.
+     * Bulk index results. Create result index first if it doesn't exist.
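+     * <p>Sketch of a call site (the surrounding wiring is assumed, not shown in this
+     * change): pass the config's custom result index (or null to use the default),
+     * the freshly produced results, and a completion listener.
+     *
+     * <pre>{@code
+     * bulkIndexingHandler.bulk(config.getCustomResultIndex(), results, configId,
+     *     ActionListener.wrap(bulkResponse -> LOG.debug("results indexed"), listener::onFailure));
+     * }</pre>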
     *
-     * @param resultIndex anomaly result index
-     * @param anomalyResults anomaly results
+     * @param resultIndex result index
+     * @param results results to save
+     * @param configId Config Id
     * @param listener action listener
     */
-    public void bulkIndexAnomalyResult(String resultIndex, List<AnomalyResult> anomalyResults, ActionListener<BulkResponse> listener) {
-        if (anomalyResults == null || anomalyResults.size() == 0) {
+    public void bulk(String resultIndex, List<ResultType> results, String configId, ActionListener<BulkResponse> listener) {
+        if (results == null || results.size() == 0) {
             listener.onResponse(null);
             return;
         }
-        String detectorId = anomalyResults.get(0).getConfigId();
+
         try {
             if (resultIndex != null) {
-                // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime
+                // Only create custom result index when creating detector, won’t recreate custom AD result index in realtime
                 // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin
                 // recreate it, that may bring confusion.
-                if (!anomalyDetectionIndices.doesIndexExist(resultIndex)) {
-                    throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + resultIndex, true);
+                if (!timeSeriesIndices.doesIndexExist(resultIndex)) {
+                    throw new EndRunException(configId, CommonMessages.CAN_NOT_FIND_RESULT_INDEX + resultIndex, true);
                 }
-                if (!anomalyDetectionIndices.isValidResultIndexMapping(resultIndex)) {
-                    throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true);
+                if (!timeSeriesIndices.isValidResultIndexMapping(resultIndex)) {
+                    throw new EndRunException(configId, "wrong index mapping of custom result index", true);
                 }
-                bulkSaveDetectorResult(resultIndex, anomalyResults, listener);
+                bulk(resultIndex, results, listener);
                 return;
             }
-            if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) {
-                anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> {
+            if (!timeSeriesIndices.doesDefaultResultIndexExist()) {
+                timeSeriesIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> {
                     if (response.isAcknowledged()) {
-                        bulkSaveDetectorResult(anomalyResults, listener);
+                        bulk(results, listener);
                     } else {
-                        String error = "Creating anomaly result index with mappings call not acknowledged";
+                        String error = "Creating result index with mappings call not acknowledged";
                         LOG.error(error);
                         listener.onFailure(new TimeSeriesException(error));
                     }
                 }, exception -> {
                     if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
                         // It is possible the index has been created while we sending the create request
-                        bulkSaveDetectorResult(anomalyResults, listener);
+                        bulk(results, listener);
                     } else {
                         listener.onFailure(exception);
                     }
                 }));
             } else {
-                bulkSaveDetectorResult(anomalyResults, listener);
+                bulk(results, listener);
             }
         } catch (TimeSeriesException e) {
             listener.onFailure(e);
         } catch (Exception e) {
-            String error = "Failed to bulk index anomaly result";
+            String error = "Failed to bulk index result";
             LOG.error(error, e);
             listener.onFailure(new TimeSeriesException(error, e));
         }
     }

-    private void bulkSaveDetectorResult(List<AnomalyResult> anomalyResults, ActionListener<BulkResponse> listener) {
-        bulkSaveDetectorResult(ANOMALY_RESULT_INDEX_ALIAS, anomalyResults, listener);
+    private void bulk(List<ResultType> anomalyResults, ActionListener<BulkResponse> listener) {
+        bulk(defaultResultIndexName, anomalyResults, listener);
     }

-    private void bulkSaveDetectorResult(String resultIndex, List<AnomalyResult> anomalyResults, ActionListener<BulkResponse> listener) {
+    private void bulk(String
 resultIndex, List<ResultType> results, ActionListener<BulkResponse> listener) {
         BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
-        anomalyResults.forEach(anomalyResult -> {
+        results.forEach(analysisResult -> {
             try (XContentBuilder builder = jsonBuilder()) {
                 IndexRequest indexRequest = new IndexRequest(resultIndex)
-                    .source(anomalyResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE));
+                    .source(analysisResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE));
                 bulkRequestBuilder.add(indexRequest);
             } catch (Exception e) {
-                String error = "Failed to prepare request to bulk index anomaly results";
+                String error = "Failed to prepare request to bulk index results";
                 LOG.error(error, e);
                 throw new TimeSeriesException(error);
             }
@@ -132,16 +153,15 @@ private void bulkSaveDetectorResult(String resultIndex, List<AnomalyResult> anom
         client.bulk(bulkRequestBuilder.request(), ActionListener.wrap(r -> {
             if (r.hasFailures()) {
                 String failureMessage = r.buildFailureMessage();
-                LOG.warn("Failed to bulk index AD result " + failureMessage);
+                LOG.warn("Failed to bulk index result " + failureMessage);
                 listener.onFailure(new TimeSeriesException(failureMessage));
             } else {
                 listener.onResponse(r);
             }
         }, e -> {
-            LOG.error("bulk index ad result failed", e);
+            LOG.error("bulk index result failed", e);
             listener.onFailure(e);
         }));
     }
-
 }
diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java
similarity index 74%
rename from src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java
rename to src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java
index 9d539f797..01c482903 100644
--- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java
+++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java
@@ -9,10 +9,9 @@
 * GitHub history for details.
*/ -package org.opensearch.ad.transport.handler; +package org.opensearch.timeseries.transport.handler; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX; import java.util.Iterator; import java.util.Locale; @@ -25,38 +24,40 @@ import org.opensearch.action.bulk.BackoffPolicy; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.BulkUtil; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.BulkUtil; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; import org.opensearch.timeseries.util.RestHandlerUtils; -public class AnomalyIndexHandler { - private static final Logger LOG = LogManager.getLogger(AnomalyIndexHandler.class); - static final String FAIL_TO_SAVE_ERR_MSG = "Fail to save %s: "; - static final String SUCCESS_SAVING_MSG = "Succeed in saving %s"; - static final String CANNOT_SAVE_ERR_MSG = "Cannot save %s due to write block."; - static final String RETRY_SAVING_ERR_MSG = "Retry in saving %s: "; +public class ResultIndexingHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + private static final Logger LOG = LogManager.getLogger(ResultIndexingHandler.class); + public static final String FAIL_TO_SAVE_ERR_MSG = "Fail to save %s: "; + public static final String SUCCESS_SAVING_MSG = "Succeed in saving %s"; + public static final String CANNOT_SAVE_ERR_MSG = "Cannot save %s due to write block."; + public static final String RETRY_SAVING_ERR_MSG = "Retry in saving %s: "; protected final Client client; protected final ThreadPool threadPool; protected final BackoffPolicy savingBackoffPolicy; - protected final String indexName; - protected final ADIndexManagement anomalyDetectionIndices; + protected final String defaultResultIndexName; + protected final IndexManagementType timeSeriesIndices; // whether save to a specific doc id or not. False by default. protected boolean fixedDoc; protected final ClientUtil clientUtil; @@ -70,30 +71,28 @@ public class AnomalyIndexHandler { * @param settings accessor for node settings. 
* @param threadPool used to invoke specific threadpool to execute * @param indexName name of index to save to - * @param anomalyDetectionIndices anomaly detection indices + * @param timeSeriesIndices anomaly detection indices * @param clientUtil client wrapper * @param indexUtils Index util classes * @param clusterService accessor to ES cluster service */ - public AnomalyIndexHandler( + public ResultIndexingHandler( Client client, Settings settings, ThreadPool threadPool, String indexName, - ADIndexManagement anomalyDetectionIndices, + IndexManagementType timeSeriesIndices, ClientUtil clientUtil, IndexUtils indexUtils, - ClusterService clusterService + ClusterService clusterService, + Setting backOffDelaySetting, + Setting maxRetrySetting ) { this.client = client; this.threadPool = threadPool; - this.savingBackoffPolicy = BackoffPolicy - .exponentialBackoff( - AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings), - AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings) - ); - this.indexName = indexName; - this.anomalyDetectionIndices = anomalyDetectionIndices; + this.savingBackoffPolicy = BackoffPolicy.exponentialBackoff(backOffDelaySetting.get(settings), maxRetrySetting.get(settings)); + this.defaultResultIndexName = indexName; + this.timeSeriesIndices = timeSeriesIndices; this.fixedDoc = false; this.clientUtil = clientUtil; this.indexUtils = indexUtils; @@ -111,8 +110,8 @@ public void setFixedDoc(boolean fixedDoc) { } // TODO: check if user has permission to index. - public void index(T toSave, String detectorId, String customIndexName) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { + public void index(ResultType toSave, String detectorId, String customIndexName) { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) { LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); return; } @@ -122,17 +121,17 @@ public void index(T toSave, String detectorId, String customIndexName) { // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin // recreate it, that may bring confusion. 
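+                // Note on the checks below: the trailing boolean on EndRunException is the
+                // "end now" flag (an assumption documented here for readers); passing true
+                // asks the job runner to stop the analysis instead of retrying, since a
+                // missing or mis-mapped custom result index will not heal on its own.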
- if (!anomalyDetectionIndices.doesIndexExist(customIndexName)) { - throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + customIndexName, true); + if (!timeSeriesIndices.doesIndexExist(customIndexName)) { + throw new EndRunException(detectorId, CommonMessages.CAN_NOT_FIND_RESULT_INDEX + customIndexName, true); } - if (!anomalyDetectionIndices.isValidResultIndexMapping(customIndexName)) { + if (!timeSeriesIndices.isValidResultIndexMapping(customIndexName)) { throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); } save(toSave, detectorId, customIndexName); return; } - if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { - anomalyDetectionIndices + if (!timeSeriesIndices.doesDefaultResultIndexExist()) { + timeSeriesIndices .initDefaultResultIndexDirectly( ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, detectorId), exception -> { if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { @@ -141,7 +140,7 @@ public void index(T toSave, String detectorId, String customIndexName) { } else { throw new TimeSeriesException( detectorId, - String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), + String.format(Locale.ROOT, "Unexpected error creating index %s", defaultResultIndexName), exception ); } @@ -153,32 +152,32 @@ public void index(T toSave, String detectorId, String customIndexName) { } catch (Exception e) { throw new TimeSeriesException( detectorId, - String.format(Locale.ROOT, "Error in saving %s for detector %s", indexName, detectorId), + String.format(Locale.ROOT, "Error in saving %s for detector %s", defaultResultIndexName, detectorId), e ); } } - private void onCreateIndexResponse(CreateIndexResponse response, T toSave, String detectorId) { + private void onCreateIndexResponse(CreateIndexResponse response, ResultType toSave, String detectorId) { if (response.isAcknowledged()) { save(toSave, detectorId); } else { throw new TimeSeriesException( detectorId, - String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", indexName) + String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", defaultResultIndexName) ); } } - protected void save(T toSave, String detectorId) { - save(toSave, detectorId, indexName); + protected void save(ResultType toSave, String detectorId) { + save(toSave, detectorId, defaultResultIndexName); } // TODO: Upgrade custom result index mapping to latest version? // It may bring some issue if we upgrade the custom result index mapping while user is using that index // for other use cases. One easy solution is to tell user only use custom result index for AD plugin. // For the first release of custom result index, it's not a issue. Will leave this to next phase. 
- protected void save(T toSave, String detectorId, String indexName) { + protected void save(ResultType toSave, String detectorId, String indexName) { try (XContentBuilder builder = jsonBuilder()) { IndexRequest indexRequest = new IndexRequest(indexName).source(toSave.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); if (fixedDoc) { @@ -192,9 +191,9 @@ protected void save(T toSave, String detectorId, String indexName) { } } - void saveIteration(IndexRequest indexRequest, String detectorId, Iterator backoff) { + void saveIteration(IndexRequest indexRequest, String configId, Iterator backoff) { clientUtil.asyncRequest(indexRequest, client::index, ActionListener.wrap(response -> { - LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, detectorId)); + LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, configId)); }, exception -> { // OpenSearch has a thread pool and a queue for write per node. A thread // pool will have N number of workers ready to handle the requests. When a @@ -210,13 +209,13 @@ void saveIteration(IndexRequest indexRequest, String detectorId, Iterator saveIteration(BulkUtil.cloneIndexRequest(indexRequest), detectorId, backoff), + () -> saveIteration(BulkUtil.cloneIndexRequest(indexRequest), configId, backoff), nextDelay, ThreadPool.Names.SAME ); diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/SearchHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/SearchHandler.java new file mode 100644 index 000000000..e4c9a893e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/handler/SearchHandler.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport.handler; + +import static org.opensearch.timeseries.util.ParseUtils.isAdmin; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * Handle general search request, check user role and return search response. + */ +public class SearchHandler { + private final Logger logger = LogManager.getLogger(SearchHandler.class); + private final Client client; + private volatile Boolean filterEnabled; + + public SearchHandler(Settings settings, ClusterService clusterService, Client client, Setting filterByBackendRoleSetting) { + this.client = client; + filterEnabled = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterEnabled = it); + } + + /** + * Validate user role, add backend role filter if filter enabled + * and execute search. 
+     *
+     * @param request search request
+     * @param actionListener action listener
+     */
+    public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
+        User user = ParseUtils.getUserContext(client);
+        ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, CommonMessages.FAIL_TO_SEARCH);
+        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
+            validateRole(request, user, listener);
+        } catch (Exception e) {
+            logger.error(e);
+            listener.onFailure(e);
+        }
+    }
+
+    private void validateRole(SearchRequest request, User user, ActionListener<SearchResponse> listener) {
+        if (user == null || !filterEnabled || isAdmin(user)) {
+            // Case 1: user == null when security is disabled or the user is a super-admin
+            // Case 2: if security is enabled and the filter is disabled, proceed with the
+            // search, as the user is already authenticated to hit this API
+            // Case 3: the user is an admin, so we don't have to check backend-role filtering
+            client.search(request, listener);
+        } else {
+            // Security is enabled, the filter is enabled, and the user isn't an admin
+            try {
+                ParseUtils.addUserBackendRolesFilter(user, request.source());
+                logger.debug("Filtering result by " + user.getBackendRoles());
+                client.search(request, listener);
+            } catch (Exception e) {
+                listener.onFailure(e);
+            }
+        }
+    }
+
+}
diff --git a/src/main/java/org/opensearch/ad/util/BulkUtil.java b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java
similarity index 96%
rename from src/main/java/org/opensearch/ad/util/BulkUtil.java
rename to src/main/java/org/opensearch/timeseries/util/BulkUtil.java
index b754b1951..c2b275a1f 100644
--- a/src/main/java/org/opensearch/ad/util/BulkUtil.java
+++ b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
 */

-package org.opensearch.ad.util;
+package org.opensearch.timeseries.util;

 import java.util.ArrayList;
 import java.util.HashSet;
@@ -23,7 +23,6 @@
 import org.opensearch.action.bulk.BulkRequest;
 import org.opensearch.action.bulk.BulkResponse;
 import org.opensearch.action.index.IndexRequest;
-import org.opensearch.timeseries.util.ExceptionUtil;

 public class BulkUtil {
     private static final Logger logger = LogManager.getLogger(BulkUtil.class);
diff --git a/src/main/java/org/opensearch/ad/util/DateUtils.java b/src/main/java/org/opensearch/timeseries/util/DateUtils.java
similarity index 96%
rename from src/main/java/org/opensearch/ad/util/DateUtils.java
rename to src/main/java/org/opensearch/timeseries/util/DateUtils.java
index e7cfc21ce..a76fc5bcb 100644
--- a/src/main/java/org/opensearch/ad/util/DateUtils.java
+++ b/src/main/java/org/opensearch/timeseries/util/DateUtils.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
*/ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.time.Duration; import java.time.Instant; diff --git a/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java b/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java index ca3ba4eba..80ffd5c9f 100644 --- a/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java +++ b/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java @@ -17,10 +17,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.timeseries.constant.CommonName; /** * Util class to filter unwanted node types @@ -91,8 +91,8 @@ public boolean test(DiscoveryNode discoveryNode) { return discoveryNode.isDataNode() && discoveryNode .getAttributes() - .getOrDefault(ADCommonName.BOX_TYPE_KEY, ADCommonName.HOT_BOX_TYPE) - .equals(ADCommonName.HOT_BOX_TYPE); + .getOrDefault(CommonName.BOX_TYPE_KEY, CommonName.HOT_BOX_TYPE) + .equals(CommonName.HOT_BOX_TYPE); } } } diff --git a/src/main/java/org/opensearch/ad/util/IndexUtils.java b/src/main/java/org/opensearch/timeseries/util/IndexUtils.java similarity index 89% rename from src/main/java/org/opensearch/ad/util/IndexUtils.java rename to src/main/java/org/opensearch/timeseries/util/IndexUtils.java index c93511849..cf845dc88 100644 --- a/src/main/java/org/opensearch/ad/util/IndexUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/IndexUtils.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.List; import java.util.Locale; @@ -17,7 +17,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.health.ClusterIndexHealth; @@ -25,7 +24,6 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.timeseries.util.ClientUtil; public class IndexUtils { /** @@ -41,28 +39,17 @@ public class IndexUtils { private static final Logger logger = LogManager.getLogger(IndexUtils.class); - private Client client; - private ClientUtil clientUtil; private ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; /** * Inject annotation required by Guice to instantiate EntityResultTransportAction (transitive dependency) * - * @param client Client to make calls to OpenSearch - * @param clientUtil AD Client utility * @param clusterService ES ClusterService * @param indexNameExpressionResolver index name resolver */ @Inject - public IndexUtils( - Client client, - ClientUtil clientUtil, - ClusterService clusterService, - IndexNameExpressionResolver indexNameExpressionResolver - ) { - this.client = client; - this.clientUtil = clientUtil; + public IndexUtils(ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver) { this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; } diff --git 
a/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java b/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java index 5d0998d27..7dd830435 100644 --- a/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java +++ b/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java @@ -19,8 +19,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.model.Mergeable; import org.opensearch.core.action.ActionListener; +import org.opensearch.timeseries.model.Mergeable; /** * A listener wrapper to help send multiple requests asynchronously and return one final responses together diff --git a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java index 0978a0de5..da1b7729f 100644 --- a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java @@ -11,7 +11,6 @@ package org.opensearch.timeseries.util; -import static org.opensearch.ad.constant.ADCommonName.EPOCH_MILLIS_FORMAT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.search.aggregations.AggregationBuilders.dateRange; import static org.opensearch.search.aggregations.AggregatorFactories.VALID_AGG_NAME; @@ -35,11 +34,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -47,7 +45,9 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.ParsingException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -69,7 +69,6 @@ import org.opensearch.search.aggregations.bucket.range.DateRangeAggregationBuilder; import org.opensearch.search.aggregations.metrics.Max; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; @@ -302,23 +301,23 @@ public static AggregatorFactories.Builder parseAggregators(XContentParser parser } public static SearchSourceBuilder generateInternalFeatureQuery( - AnomalyDetector detector, + Config config, long startTime, long endTime, NamedXContentRegistry xContentRegistry ) throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .from(startTime) .to(endTime) .format("epoch_millis") .includeLower(true) 
.includeUpper(false); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(config.getFilterQuery()); SearchSourceBuilder internalSearchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery); - if (detector.getFeatureAttributes() != null) { - for (Feature feature : detector.getFeatureAttributes()) { + if (config.getFeatureAttributes() != null) { + for (Feature feature : config.getFeatureAttributes()) { AggregatorFactories.Builder internalAgg = parseAggregators( feature.getAggregation().toString(), xContentRegistry, @@ -366,7 +365,7 @@ public static SearchSourceBuilder generateColdStartQuery( BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(config.getFilterQuery()); if (entity.isPresent()) { - for (TermQueryBuilder term : entity.get().getTermQueryBuilders()) { + for (TermQueryBuilder term : entity.get().getTermQueryForCustomerIndex()) { internalFilterQuery.filter(term); } } @@ -393,12 +392,12 @@ public static SearchSourceBuilder generateColdStartQuery( /** * Map feature data to its Id and name * @param currentFeature Feature data - * @param detector Detector Config object + * @param config Config object * @return a list of feature data with Id and name */ - public static List getFeatureData(double[] currentFeature, AnomalyDetector detector) { - List featureIds = detector.getEnabledFeatureIds(); - List featureNames = detector.getEnabledFeatureNames(); + public static List getFeatureData(double[] currentFeature, Config config) { + List featureIds = config.getEnabledFeatureIds(); + List featureNames = config.getEnabledFeatureNames(); int featureLen = featureIds.size(); List featureData = new ArrayList<>(); for (int i = 0; i < featureLen; i++) { @@ -425,6 +424,7 @@ public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSou } else if (query instanceof BoolQueryBuilder) { ((BoolQueryBuilder) query).filter(boolQueryBuilder); } else { + // e.g., wild card query throw new TimeSeriesException("Search API does not support queries other than BoolQuery"); } return searchSourceBuilder; @@ -444,7 +444,20 @@ public static User getUserContext(Client client) { return User.parse(userStr); } - public static void resolveUserAndExecute( + /** + * run the given function based on given user + * @param Config response type. Can be either GetAnomalyDetectorResponse or GetForecasterResponse + * @param requestedUser requested user + * @param configId config Id + * @param filterByEnabled filter by backend is enabled + * @param listener listener. We didn't provide the generic type of listener and therefore can return anything using the listener. + * @param function Function to execute + * @param client Client to OS. + * @param clusterService Cluster service of OS. + * @param xContentRegistry Used to deserialize the get config response. 
+ * @param configTypeClass the class of the ConfigType, used by the ConfigFactory to parse the correct type of Config + */ + public static void resolveUserAndExecute( User requestedUser, String configId, boolean filterByEnabled, @@ -491,10 +504,10 @@ public static void resolveUserAndExecute( * @param filterByBackendRole filter by backend role or not * @param configTypeClass the class of the ConfigType, used by the ConfigFactory to parse the correct type of Config */ - public static void getConfig( + public static void getConfig( User requestUser, String configId, - ActionListener listener, + ActionListener listener, Consumer function, Client client, ClusterService clusterService, @@ -520,7 +533,7 @@ public static void getConfig( configTypeClass ), exception -> { - logger.error("Failed to get anomaly detector: " + configId, exception); + logger.error("Failed to get config: " + configId, exception); listener.onFailure(exception); } ) @@ -542,6 +555,7 @@ public static void getConfig( * provided the user holds the requisite permissions. * * @param The type of Config to be processed in this method, which extends from the Config base type. + * @param The type of ActionResponse to be used, which extends from the ActionResponse base type. * @param response The GetResponse from the getConfig request. This contains the information about the config that is to be processed. * @param requestUser The User from the request. This user's permissions will be checked to ensure they have access to the config. * @param configId The ID of the config. This is used for logging and error messages. @@ -551,11 +565,11 @@ public static void getConfig( * @param filterByBackendRole A boolean indicating whether to filter by backend role. If true, the user's backend roles will be checked to ensure they have access to the config. * @param configTypeClass The class of the ConfigType, used by the ConfigFactory to parse the correct type of Config. 
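+     * <p>On failure, callers now receive an OpenSearchStatusException whose status encodes
+     * the cause: NOT_FOUND when the config document is missing, FORBIDDEN when backend-role
+     * filtering denies access, and BAD_REQUEST when the stored user field cannot be parsed.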
*/ - public static void onGetConfigResponse( + public static void onGetConfigResponse( GetResponse response, User requestUser, String configId, - ActionListener listener, + ActionListener listener, Consumer function, NamedXContentRegistry xContentRegistry, boolean filterByBackendRole, @@ -574,13 +588,17 @@ public static void onGetConfigResponse( function.accept(config); } else { logger.debug("User: " + requestUser.getName() + " does not have permissions to access config: " + configId); - listener.onFailure(new TimeSeriesException(CommonMessages.NO_PERMISSION_TO_ACCESS_CONFIG + configId)); + listener + .onFailure( + new OpenSearchStatusException(CommonMessages.NO_PERMISSION_TO_ACCESS_CONFIG + configId, RestStatus.FORBIDDEN) + ); } } catch (Exception e) { - listener.onFailure(new TimeSeriesException(CommonMessages.FAIL_TO_GET_USER_INFO + configId)); + logger.error("Fail to parse user out of config", e); + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_GET_USER_INFO + configId, RestStatus.BAD_REQUEST)); } } else { - listener.onFailure(new ResourceNotFoundException(configId, FAIL_TO_FIND_CONFIG_MSG + configId)); + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); } } @@ -596,7 +614,7 @@ public static boolean isAdmin(User user) { return user.getRoles().contains("all_access"); } - private static boolean checkUserPermissions(User requestedUser, User resourceUser, String detectorId) throws Exception { + private static boolean checkUserPermissions(User requestedUser, User resourceUser, String configId) throws Exception { if (resourceUser.getBackendRoles() == null || requestedUser.getBackendRoles() == null) { return false; } @@ -609,8 +627,8 @@ private static boolean checkUserPermissions(User requestedUser, User resourceUse + requestedUser.getName() + " has backend role: " + backendRole - + " permissions to access detector: " - + detectorId + + " permissions to access config: " + + configId ); return true; } @@ -618,20 +636,15 @@ private static boolean checkUserPermissions(User requestedUser, User resourceUse return false; } - public static boolean checkFilterByBackendRoles(User requestedUser, ActionListener listener) { + public static String checkFilterByBackendRoles(User requestedUser) { if (requestedUser == null) { - return false; + return "Filter by backend roles is enabled and User is null"; } if (requestedUser.getBackendRoles().isEmpty()) { - listener - .onFailure( - new TimeSeriesException( - "Filter by backend roles is enabled and User " + requestedUser.getName() + " does not have backend roles configured" - ) - ); - return false; + return String + .format("Filter by backend roles is enabled and User %s does not have backend roles configured", requestedUser.getName()); } - return true; + return null; } /** @@ -651,7 +664,7 @@ public static Optional getLatestDataTime(SearchResponse searchResponse) { /** * Generate batch query request for feature aggregation on given date range. 
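+     * <p>Sketch of a call site (names assumed; config.getIndices() is taken to be the
+     * config's source indices): the returned builder is attached to a plain SearchRequest.
+     *
+     * <pre>{@code
+     * SearchSourceBuilder source = ParseUtils.batchFeatureQuery(config, entity, pieceStart, pieceEnd, xContentRegistry);
+     * SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(source);
+     * }</pre>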
* - * @param detector anomaly detector + * @param config config accessor * @param entity entity * @param startTime start time * @param endTime end time @@ -661,46 +674,46 @@ public static Optional getLatestDataTime(SearchResponse searchResponse) { * @throws TimeSeriesException throw AD exception if no enabled feature */ public static SearchSourceBuilder batchFeatureQuery( - AnomalyDetector detector, + Config config, Entity entity, long startTime, long endTime, NamedXContentRegistry xContentRegistry ) throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .from(startTime) .to(endTime) - .format(EPOCH_MILLIS_FORMAT) + .format(CommonName.EPOCH_MILLIS_FORMAT) .includeLower(true) .includeUpper(false); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(config.getFilterQuery()); - if (detector.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { + if (config.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { entity .getAttributes() .entrySet() .forEach(attr -> { internalFilterQuery.filter(new TermQueryBuilder(attr.getKey(), attr.getValue())); }); } - long intervalSeconds = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().getSeconds(); + long intervalSeconds = ((IntervalTimeConfiguration) config.getInterval()).toDuration().getSeconds(); List> sources = new ArrayList<>(); sources .add( new DateHistogramValuesSourceBuilder(CommonName.DATE_HISTOGRAM) - .field(detector.getTimeField()) + .field(config.getTimeField()) .fixedInterval(DateHistogramInterval.seconds((int) intervalSeconds)) ); CompositeAggregationBuilder aggregationBuilder = new CompositeAggregationBuilder(CommonName.FEATURE_AGGS, sources) .size(MAX_BATCH_TASK_PIECE_SIZE); - if (detector.getEnabledFeatureIds().size() == 0) { + if (config.getEnabledFeatureIds().size() == 0) { throw new TimeSeriesException("No enabled feature configured").countedInStats(false); } - for (Feature feature : detector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { if (feature.getEnabled()) { AggregatorFactories.Builder internalAgg = parseAggregators( feature.getAggregation().toString(), @@ -776,9 +789,9 @@ public static List parseAggregationRequest(XContentParser parser) throws return fieldNames; } - public static List getFeatureFieldNames(AnomalyDetector detector, NamedXContentRegistry xContentRegistry) throws IOException { + public static List getFeatureFieldNames(Config config, NamedXContentRegistry xContentRegistry) throws IOException { List featureFields = new ArrayList<>(); - for (Feature feature : detector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { featureFields.add(getFieldNamesForFeature(feature, xContentRegistry).get(0)); } return featureFields; diff --git a/src/main/java/org/opensearch/timeseries/util/QueryUtil.java b/src/main/java/org/opensearch/timeseries/util/QueryUtil.java new file mode 100644 index 000000000..e98a5d248 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/QueryUtil.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import java.util.Collections; + +import org.opensearch.script.Script; +import 
org.opensearch.script.ScriptType; + +import com.google.common.collect.ImmutableMap; + +public class QueryUtil { + /** + * Generates the painless script to fetch results that have an entity name matching the passed-in category field. + * + * @param categoryField the category field to be used as a source + * @return the painless script used to get all docs with entity name values matching the category field + */ + public static Script getScriptForCategoryField(String categoryField) { + StringBuilder builder = new StringBuilder() + .append("String value = null;") + .append("if (params == null || params._source == null || params._source.entity == null) {") + .append("return \"\"") + .append("}") + .append("for (item in params._source.entity) {") + .append("if (item[\"name\"] == params[\"categoryField\"]) {") + .append("value = item['value'];") + .append("break;") + .append("}") + .append("}") + .append("return value;"); + + // The last argument contains the K/V pair to inject the categoryField value into the script + return new Script( + ScriptType.INLINE, + "painless", + builder.toString(), + Collections.emptyMap(), + ImmutableMap.of("categoryField", categoryField) + ); + } +} diff --git a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java index 45e318aa2..47ba48dba 100644 --- a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import org.apache.commons.lang.ArrayUtils; @@ -45,7 +46,9 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import com.google.common.base.Throwables; @@ -63,11 +66,8 @@ public final class RestHandlerUtils { public static final String _PRIMARY_TERM = "_primary_term"; public static final String IF_PRIMARY_TERM = "if_primary_term"; public static final String REFRESH = "refresh"; - public static final String DETECTOR_ID = "detectorID"; public static final String RESULT_INDEX = "resultIndex"; - public static final String ANOMALY_DETECTOR = "anomaly_detector"; - public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; - public static final String REALTIME_TASK = "realtime_detection_task"; + public static final String REALTIME_TASK = "realtime_task"; public static final String HISTORICAL_ANALYSIS_TASK = "historical_analysis_task"; public static final String RUN = "_run"; public static final String PREVIEW = "_preview"; @@ -79,16 +79,31 @@ public final class RestHandlerUtils { public static final String COUNT = "count"; public static final String MATCH = "match"; public static final String RESULTS = "results"; - public static final String TOP_ANOMALIES = "_topAnomalies"; public static final String VALIDATE = "_validate"; + public static final String SEARCH = "_search"; public static final ToXContent.MapParams XCONTENT_WITH_TYPE = new ToXContent.MapParams(ImmutableMap.of("with_type", "true")); + public static final String REST_STATUS = "rest_status"; + public static final String RUN_ONCE = 
"_run_once"; + public static final String SUGGEST = "_suggest"; + public static final String RUN_ONCE_TASK = "run_once_task"; public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards"; public static final String[] UI_METADATA_EXCLUDE = new String[] { Config.UI_METADATA_FIELD }; + public static final String NODE_ID = "nodeId"; + public static final String STATS = "stats"; + public static final String STAT = "stat"; + + // AD constants + public static final String DETECTOR_ID = "detectorID"; + public static final String ANOMALY_DETECTOR = "anomaly_detector"; + public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; + public static final String TOP_ANOMALIES = "_topAnomalies"; + // forecast constants public static final String FORECASTER_ID = "forecasterID"; public static final String FORECASTER = "forecaster"; - public static final String REST_STATUS = "rest_status"; + public static final String FORECASTER_JOB = "forecaster_job"; + public static final String TOP_FORECASTS = "_topForecasts"; private RestHandlerUtils() {} @@ -247,4 +262,32 @@ public static boolean isProperExceptionToReturn(Throwable e) { private static String coalesceToEmpty(@Nullable String s) { return s == null ? "" : s; } + + public static Entity buildEntity(RestRequest request, String detectorId) throws IOException { + if (org.opensearch.core.common.Strings.isEmpty(detectorId)) { + throw new IllegalStateException(CommonMessages.CONFIG_ID_MISSING_MSG); + } + + String entityName = request.param(CommonName.CATEGORICAL_FIELD); + String entityValue = request.param(CommonName.ENTITY_KEY); + + if (entityName != null && entityValue != null) { + // single-stream profile request: + // GET + // _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= + return Entity.createSingleAttributeEntity(entityName, entityValue); + } else if (request.hasContent()) { + /* + * HCAD profile request: GET + * _plugins/_anomaly_detection/detectors//_profile/init_progress { + * "entity": [{ "name": "clientip", "value": "13.24.0.0" }] } + */ + Optional entity = Entity.fromJsonObject(request.contentParser()); + if (entity.isPresent()) { + return entity.get(); + } + } + // not a valid profile request with correct entity information + return null; + } } diff --git a/src/main/java/org/opensearch/timeseries/util/TaskUtil.java b/src/main/java/org/opensearch/timeseries/util/TaskUtil.java new file mode 100644 index 000000000..a92d81043 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/TaskUtil.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; + +import java.util.List; + +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TaskType; + +public class TaskUtil { + public static List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag, AnalysisType analysisType) { + if (analysisType == AnalysisType.FORECAST) { + if (dateRange == null) { + return ForecastTaskType.REALTIME_TASK_TYPES; + } else { + throw new UnsupportedOperationException("Forecasting does not support historical tasks"); + } + } 
else { + if (dateRange == null) { + return REALTIME_TASK_TYPES; + } else { + if (resetLatestTaskStateFlag) { + // return all task types include HC entity task to make sure we can reset all tasks latest flag + return ALL_HISTORICAL_TASK_TYPES; + } else { + return HISTORICAL_DETECTOR_TASK_TYPES; + } + } + } + + } +} diff --git a/src/main/resources/mappings/anomaly-checkpoint.json b/src/main/resources/mappings/anomaly-checkpoint.json index 5e515a803..162d28c6d 100644 --- a/src/main/resources/mappings/anomaly-checkpoint.json +++ b/src/main/resources/mappings/anomaly-checkpoint.json @@ -1,7 +1,7 @@ { "dynamic": true, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "detectorId": { @@ -35,6 +35,46 @@ }, "modelV2": { "type": "text" + }, + "samples": { + "type": "nested", + "properties": { + "value_list": { + "type": "nested", + "properties": { + "feature_id": { + "type": "keyword" + }, + "data": { + "type": "double" + } + } + }, + "data_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "data_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } + }, + "last_processed_sample": { + "type": "nested", + "properties": { + "value_list": { + "type": "double" + }, + "data_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "data_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } } } } diff --git a/src/main/resources/mappings/anomaly-results.json b/src/main/resources/mappings/anomaly-results.json index 8c377e78e..3fad67ec2 100644 --- a/src/main/resources/mappings/anomaly-results.json +++ b/src/main/resources/mappings/anomaly-results.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 5 + "schema_version": 6 }, "properties": { "detector_id": { @@ -25,6 +25,9 @@ "feature_id": { "type": "keyword" }, + "feature_name": { + "type": "keyword" + }, "data": { "type": "double" } diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json index 7db1e6d08..c64a697e7 100644 --- a/src/main/resources/mappings/config.json +++ b/src/main/resources/mappings/config.json @@ -150,6 +150,23 @@ }, "detector_type": { "type": "keyword" + }, + "forecast_interval": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } + } + }, + "horizon": { + "type": "integer" } } } diff --git a/src/main/resources/mappings/forecast-results.json b/src/main/resources/mappings/forecast-results.json index 745d308ad..6e6bbdc92 100644 --- a/src/main/resources/mappings/forecast-results.json +++ b/src/main/resources/mappings/forecast-results.json @@ -1,5 +1,5 @@ { - "dynamic": true, + "dynamic": false, "_meta": { "schema_version": 1 }, @@ -13,6 +13,9 @@ "feature_id": { "type": "keyword" }, + "feature_name": { + "type": "keyword" + }, "data": { "type": "double" } @@ -95,9 +98,6 @@ "task_id": { "type": "keyword" }, - "model_id": { - "type": "keyword" - }, "entity_id": { "type": "keyword" }, diff --git a/src/main/resources/mappings/job.json b/src/main/resources/mappings/job.json index fb26d56d2..5783c701d 100644 --- a/src/main/resources/mappings/job.json +++ b/src/main/resources/mappings/job.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "schema_version": { @@ -100,6 +100,9 @@ } } } + }, + "type": { + "type": "keyword" } } } diff --git a/src/test/java/org/opensearch/StreamInputOutputTests.java 
b/src/test/java/org/opensearch/StreamInputOutputTests.java index 82ff5cc24..1fff02fa3 100644 --- a/src/test/java/org/opensearch/StreamInputOutputTests.java +++ b/src/test/java/org/opensearch/StreamInputOutputTests.java @@ -26,15 +26,7 @@ import java.util.Set; import org.opensearch.action.FailedNodeException; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileRequest; -import org.opensearch.ad.transport.EntityProfileResponse; -import org.opensearch.ad.transport.EntityResultRequest; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADEntityProfileAction; import org.opensearch.ad.transport.RCFResultResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -42,7 +34,16 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.transport.EntityProfileRequest; +import org.opensearch.timeseries.transport.EntityProfileResponse; +import org.opensearch.timeseries.transport.EntityResultRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; /** * Put in core package so that we can using Version's package private constructor @@ -98,7 +99,7 @@ private void setUpEntityResultRequest() { entities.put(entity, feature); start = 10L; end = 20L; - entityResultRequest = new EntityResultRequest(detectorId, entities, start, end); + entityResultRequest = new EntityResultRequest(detectorId, entities, start, end, AnalysisType.AD, null); } /** @@ -111,7 +112,7 @@ public void testDeSerializeEntityResultRequest() throws IOException { StreamInput streamInput = output.bytes().streamInput(); EntityResultRequest readRequest = new EntityResultRequest(streamInput); - assertThat(readRequest.getId(), equalTo(detectorId)); + assertThat(readRequest.getConfigId(), equalTo(detectorId)); assertThat(readRequest.getStart(), equalTo(start)); assertThat(readRequest.getEnd(), equalTo(end)); assertTrue(areEqualWithArrayValue(readRequest.getEntities(), entities)); @@ -133,7 +134,7 @@ public void testDeserializeEntityProfileRequest() throws IOException { StreamInput streamInput = output.bytes().streamInput(); EntityProfileRequest readRequest = new EntityProfileRequest(streamInput); - assertThat(readRequest.getAdID(), equalTo(detectorId)); + assertThat(readRequest.getConfigID(), equalTo(detectorId)); assertThat(readRequest.getEntityValue(), equalTo(entity)); assertThat(readRequest.getProfilesToCollect(), equalTo(profilesToCollect)); } @@ -157,7 +158,7 @@ public void testDeserializeEntityProfileResponse() throws IOException { entityProfileResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - EntityProfileResponse readResponse = EntityProfileAction.INSTANCE.getResponseReader().read(streamInput); + EntityProfileResponse readResponse = 
ADEntityProfileAction.INSTANCE.getResponseReader().read(streamInput); assertThat(readResponse.getModelProfile(), equalTo(entityProfileResponse.getModelProfile())); assertThat(readResponse.getLastActiveMs(), equalTo(entityProfileResponse.getLastActiveMs())); assertThat(readResponse.getTotalUpdates(), equalTo(entityProfileResponse.getTotalUpdates())); diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index aa2f30b02..72b9caf64 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -22,9 +22,9 @@ import static org.mockito.Mockito.when; import java.io.IOException; -import java.time.Clock; import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -54,7 +54,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.rest.RestRequest; @@ -64,6 +63,7 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.util.SecurityClientUtil; @@ -78,7 +78,6 @@ */ public class IndexAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTest { static ThreadPool threadPool; - private ThreadContext threadContext; private String TEXT_FIELD_TYPE = "text"; private IndexAnomalyDetectorActionHandler handler; private ClusterService clusterService; @@ -96,11 +95,11 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTe private Integer maxSingleEntityAnomalyDetectors; private Integer maxMultiEntityAnomalyDetectors; private Integer maxAnomalyFeatures; + private Integer maxCategoricalFields; private Settings settings; private RestRequest.Method method; private ADTaskManager adTaskManager; private SearchFeatureDao searchFeatureDao; - private Clock clock; @BeforeClass public static void beforeClass() { @@ -122,7 +121,6 @@ public void setUp() throws Exception { settings = Settings.EMPTY; clusterService = mock(ClusterService.class); clientMock = spy(new NodeClient(settings, threadPool)); - clock = mock(Clock.class); NodeStateManager nodeStateManager = mock(NodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, settings); transportService = mock(TransportService.class); @@ -149,6 +147,8 @@ public void setUp() throws Exception { maxAnomalyFeatures = 5; + maxCategoricalFields = 2; + method = RestRequest.Method.POST; adTaskManager = mock(ADTaskManager.class); @@ -160,7 +160,6 @@ public void setUp() throws Exception { clientMock, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -171,6 +170,7 @@ public void setUp() throws Exception { maxSingleEntityAnomalyDetectors, 
maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -188,8 +188,7 @@ public void testThreeCategoricalFields() throws IOException { ); } - @SuppressWarnings("unchecked") - public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { + public void testMoreThanTenThousandSingleEntityDetectors() throws IOException, InterruptedException { SearchResponse mockResponse = mock(SearchResponse.class); int totalHits = 1001; when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); @@ -211,7 +210,6 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -223,6 +221,7 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -231,7 +230,9 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientMock, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -240,14 +241,14 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { String errorMsg = String .format( Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors ); assertTrue(value.getMessage().contains(errorMsg)); } @SuppressWarnings("unchecked") - public void testTextField() throws IOException { + public void testTextField() throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -289,7 +290,6 @@ public void doE client, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -300,6 +300,7 @@ public void doE maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -310,16 +311,18 @@ public void doE ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof Exception); - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + assertTrue(value.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); } @SuppressWarnings("unchecked") - private void testValidTypeTemplate(String filedTypeName) throws IOException { + private void testValidTypeTemplate(String filedTypeName) throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = 
TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -376,7 +379,6 @@ public void doE clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -387,6 +389,7 @@ public void doE maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -397,7 +400,9 @@ public void doE ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -406,16 +411,16 @@ public void doE assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); } - public void testIpField() throws IOException { + public void testIpField() throws IOException, InterruptedException { testValidTypeTemplate(CommonName.IP_TYPE); } - public void testKeywordField() throws IOException { + public void testKeywordField() throws IOException, InterruptedException { testValidTypeTemplate(CommonName.KEYWORD_TYPE); } @SuppressWarnings("unchecked") - private void testUpdateTemplate(String fieldTypeName) throws IOException { + private void testUpdateTemplate(String fieldTypeName) throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -474,7 +479,6 @@ public void doE clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -485,6 +489,7 @@ public void doE maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, RestRequest.Method.PUT, xContentRegistry(), null, @@ -495,7 +500,9 @@ public void doE ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -503,22 +510,22 @@ public void doE if (fieldTypeName.equals(CommonName.IP_TYPE) || fieldTypeName.equals(CommonName.KEYWORD_TYPE)) { assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); } else { - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + assertTrue(value.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); } } @Ignore - public void testUpdateIpField() throws IOException { + public void testUpdateIpField() throws IOException, InterruptedException { testUpdateTemplate(CommonName.IP_TYPE); } @Ignore - public void testUpdateKeywordField() throws IOException { + public void testUpdateKeywordField() throws IOException, InterruptedException { testUpdateTemplate(CommonName.KEYWORD_TYPE); } @Ignore - public void testUpdateTextField() throws IOException { + public void testUpdateTextField() throws IOException, InterruptedException { testUpdateTemplate(TEXT_FIELD_TYPE); } @@ -558,7 
+565,7 @@ public void doE } @SuppressWarnings("unchecked") - public void testMoreThanTenMultiEntityDetectors() throws IOException { + public void testMoreThanTenMultiEntityDetectors() throws IOException, InterruptedException { String field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); SearchResponse detectorResponse = mock(SearchResponse.class); @@ -580,7 +587,6 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { clientSpy, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -591,6 +597,7 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -599,24 +606,22 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientSpy, times(1)).search(any(SearchRequest.class), any()); verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof IllegalArgumentException); String errorMsg = String - .format( - Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, - maxMultiEntityAnomalyDetectors - ); + .format(Locale.ROOT, IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); assertTrue(value.getMessage().contains(errorMsg)); } @Ignore @SuppressWarnings("unchecked") - public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException { + public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException, InterruptedException { int totalHits = 10; AnomalyDetector existingDetector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, null); GetResponse getDetectorResponse = TestHelpers @@ -668,7 +673,6 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx clientMock, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -679,6 +683,7 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, RestRequest.Method.PUT, xContentRegistry(), null, @@ -687,7 +692,9 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientMock, times(1)).search(any(SearchRequest.class), any()); @@ -695,12 +702,12 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof IllegalArgumentException); - 
assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG)); } @Ignore @SuppressWarnings("unchecked") - public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException { + public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException, InterruptedException { int totalHits = 10; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); GetResponse getDetectorResponse = TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX); @@ -751,7 +758,6 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx clientMock, clientUtil, transportService, - channel, anomalyDetectionIndices, detectorId, seqNo, @@ -762,6 +768,7 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, RestRequest.Method.PUT, xContentRegistry(), null, @@ -770,7 +777,9 @@ public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOEx Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientMock, times(0)).search(any(SearchRequest.class), any()); diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java index 4873d1501..413892755 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java @@ -23,6 +23,8 @@ import java.time.Clock; import java.util.Arrays; import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -31,13 +33,15 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -53,6 +57,8 @@ import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.model.ValidationAspect; 
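// ---------------------------------------------------------------------------------------------
// Editor's note (not part of the patch): the recurring change in the hunks above -- a blocking
// handler.start() becoming handler.start(ActionListener.wrap(...)) plus a CountDownLatch -- is
// the test-side adaptation to the handlers' new asynchronous contract. A minimal, self-contained
// sketch of that wait-then-verify pattern; the Handler interface below is a hypothetical
// stand-in for the real action handlers, everything else mirrors the diff:

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.opensearch.core.action.ActionListener;

class AsyncStartTestSketch {

    interface Handler {
        void start(ActionListener<Void> listener); // stand-in for the real handler's start(...)
    }

    static boolean startAndAwait(Handler handler) throws InterruptedException {
        final CountDownLatch inProgressLatch = new CountDownLatch(1);
        // Count down on success and on failure, so the test can go on to verify mocks either way.
        handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown()));
        // Bounded await: returns false instead of hanging forever if neither callback fires.
        return inProgressLatch.await(100, TimeUnit.SECONDS);
    }
}
// ---------------------------------------------------------------------------------------------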
+import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -60,9 +66,9 @@ public class ValidateAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTest { - protected AbstractAnomalyDetectorActionHandler handler; + protected ValidateAnomalyDetectorActionHandler handler; protected ClusterService clusterService; - protected ActionListener channel; + protected ActionListener channel; protected TransportService transportService; protected ADIndexManagement anomalyDetectionIndices; protected String detectorId; @@ -74,9 +80,10 @@ public class ValidateAnomalyDetectorActionHandlerTests extends AbstractTimeSerie protected Integer maxSingleEntityAnomalyDetectors; protected Integer maxMultiEntityAnomalyDetectors; protected Integer maxAnomalyFeatures; + protected Integer maxCategoricalFields; protected Settings settings; protected RestRequest.Method method; - protected ADTaskManager adTaskManager; + protected TaskManager adTaskManager; protected SearchFeatureDao searchFeatureDao; protected Clock clock; @@ -116,6 +123,7 @@ public void setUp() throws Exception { maxSingleEntityAnomalyDetectors = 1000; maxMultiEntityAnomalyDetectors = 10; maxAnomalyFeatures = 5; + maxCategoricalFields = 10; method = RestRequest.Method.POST; adTaskManager = mock(ADTaskManager.class); searchFeatureDao = mock(SearchFeatureDao.class); @@ -126,7 +134,7 @@ public void setUp() throws Exception { } @SuppressWarnings("unchecked") - public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOException { + public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOException, InterruptedException { SearchResponse mockResponse = mock(SearchResponse.class); int totalHits = maxSingleEntityAnomalyDetectors + 1; when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); @@ -150,13 +158,13 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc clusterService, clientSpy, clientUtil, - channel, anomalyDetectionIndices, singleEntityDetector, requestTimeout, maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -165,7 +173,9 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc clock, settings ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); @@ -174,14 +184,14 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc String errorMsg = String .format( Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors ); assertTrue(value.getMessage().contains(errorMsg)); } @SuppressWarnings("unchecked") - public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOException { + public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOException, InterruptedException { String 
field = "a"; AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); @@ -204,13 +214,13 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio clusterService, clientSpy, clientUtil, - channel, anomalyDetectionIndices, detector, requestTimeout, maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry(), null, @@ -219,18 +229,16 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio clock, Settings.EMPTY ); - handler.start(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + handler.start(ActionListener.wrap(r -> inProgressLatch.countDown(), e -> inProgressLatch.countDown())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof ValidationException); String errorMsg = String - .format( - Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, - maxMultiEntityAnomalyDetectors - ); + .format(Locale.ROOT, IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); assertTrue(value.getMessage().contains(errorMsg)); } } diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java index a98eef88d..69fb5176d 100644 --- a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java @@ -19,7 +19,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import java.time.Clock; import java.util.Arrays; import java.util.HashSet; import java.util.Optional; @@ -33,7 +32,6 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.client.Client; @@ -44,6 +42,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.ProfileName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -78,12 +77,12 @@ protected enum ErrorResultStatus { protected TransportService transportService; protected ADTaskManager adTaskManager; - protected static Set stateOnly; - protected static Set stateNError; - protected static Set modelProfile; - protected static Set stateInitProgress; - protected static Set totalInitProgress; - protected static Set initProgressErrorProfile; + protected static Set stateOnly; + protected static Set stateNError; + protected static Set modelProfile; + protected static Set stateInitProgress; + protected static Set totalInitProgress; + protected static Set initProgressErrorProfile; protected static String noFullShingleError = "No full shingle in current detection window"; protected static String stoppedError = @@ -113,32 +112,23 @@ protected enum 
ErrorResultStatus { protected int detectorIntervalMin; protected GetResponse detectorGetReponse; protected String messaingExceptionError = "blah"; + protected ADTaskProfileRunner taskProfileRunner; @BeforeClass public static void setUpOnce() { - stateOnly = new HashSet(); - stateOnly.add(DetectorProfileName.STATE); - stateNError = new HashSet(); - stateNError.add(DetectorProfileName.ERROR); - stateNError.add(DetectorProfileName.STATE); - stateInitProgress = new HashSet(); - stateInitProgress.add(DetectorProfileName.INIT_PROGRESS); - stateInitProgress.add(DetectorProfileName.STATE); - modelProfile = new HashSet( - Arrays - .asList( - DetectorProfileName.SHINGLE_SIZE, - DetectorProfileName.MODELS, - DetectorProfileName.COORDINATING_NODE, - DetectorProfileName.TOTAL_SIZE_IN_BYTES - ) - ); - totalInitProgress = new HashSet( - Arrays.asList(DetectorProfileName.TOTAL_ENTITIES, DetectorProfileName.INIT_PROGRESS) - ); - initProgressErrorProfile = new HashSet( - Arrays.asList(DetectorProfileName.INIT_PROGRESS, DetectorProfileName.ERROR) + stateOnly = new HashSet(); + stateOnly.add(ProfileName.STATE); + stateNError = new HashSet(); + stateNError.add(ProfileName.ERROR); + stateNError.add(ProfileName.STATE); + stateInitProgress = new HashSet(); + stateInitProgress.add(ProfileName.INIT_PROGRESS); + stateInitProgress.add(ProfileName.STATE); + modelProfile = new HashSet( + Arrays.asList(ProfileName.SHINGLE_SIZE, ProfileName.MODELS, ProfileName.COORDINATING_NODE, ProfileName.TOTAL_SIZE_IN_BYTES) ); + totalInitProgress = new HashSet(Arrays.asList(ProfileName.TOTAL_ENTITIES, ProfileName.INIT_PROGRESS)); + initProgressErrorProfile = new HashSet(Arrays.asList(ProfileName.INIT_PROGRESS, ProfileName.ERROR)); clusterName = "test-cluster-name"; discoveryNode1 = new DiscoveryNode( "nodeName1", @@ -163,7 +153,7 @@ public void setUp() throws Exception { super.setUp(); client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); - Clock clock = mock(Clock.class); + taskProfileRunner = mock(ADTaskProfileRunner.class); nodeFilter = mock(DiscoveryNodeFilterer.class); clusterService = mock(ClusterService.class); @@ -178,7 +168,7 @@ public void setUp() throws Exception { Consumer> function = (Consumer>) args[2]; function.accept(Optional.of(TestHelpers.randomAdTask())); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); detectorIntervalMin = 3; detectorGetReponse = mock(GetResponse.class); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index ed5be8fb0..106465dc6 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -23,7 +23,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES; import java.io.IOException; import java.time.Instant; @@ -53,6 +52,7 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import 
org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; @@ -61,7 +61,6 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -84,6 +83,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.JobProcessor; +import org.opensearch.timeseries.JobRunner; import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; @@ -94,6 +95,7 @@ import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -118,7 +120,9 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { @Mock private JobExecutionContext context; - private AnomalyDetectorJobRunner runner = AnomalyDetectorJobRunner.getJobRunnerInstance(); + private JobRunner runner = JobRunner.getJobRunnerInstance(); + + private ADJobProcessor adJobProcessor = ADJobProcessor.getInstance(); @Mock private ThreadPool mockedThreadPool; @@ -129,7 +133,7 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { private Iterator backoff; @Mock - private AnomalyIndexHandler anomalyResultHandler; + private ResultBulkIndexingHandler anomalyResultHandler; @Mock private ADTaskManager adTaskManager; @@ -163,7 +167,7 @@ public static void tearDownAfterClass() { @Before public void setup() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(AnomalyDetectorJobRunner.class); + super.setUpLog4jForJUnit(JobProcessor.class); MockitoAnnotations.initMocks(this); ThreadFactory threadFactory = OpenSearchExecutors.daemonThreadFactory(OpenSearchExecutors.threadName("node1", "test-ad")); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @@ -171,9 +175,9 @@ public void setup() throws Exception { Mockito.doReturn(executorService).when(mockedThreadPool).executor(anyString()); Mockito.doReturn(mockedThreadPool).when(client).threadPool(); Mockito.doReturn(threadContext).when(mockedThreadPool).getThreadContext(); - runner.setThreadPool(mockedThreadPool); - runner.setClient(client); - runner.setAdTaskManager(adTaskManager); + adJobProcessor.setThreadPool(mockedThreadPool); + adJobProcessor.setClient(client); + adJobProcessor.setTaskManager(adTaskManager); Settings settings = Settings .builder() @@ -183,11 +187,11 @@ public void setup() throws Exception { .build(); setUpJobParameter(); - runner.setSettings(settings); + adJobProcessor.registerSettings(settings); anomalyDetectionIndices = mock(ADIndexManagement.class); - runner.setAnomalyDetectionIndices(anomalyDetectionIndices); + adJobProcessor.setIndexManagement(anomalyDetectionIndices); lockService = new LockService(client, clusterService); doReturn(lockService).when(context).getLockService(); @@ -235,7 +239,7 @@ public void setup() throws Exception { listener.onResponse(Optional.of(detector)); return null; 
}).when(nodeStateManager).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); - runner.setNodeStateManager(nodeStateManager); + adJobProcessor.setNodeStateManager(nodeStateManager); recorder = new ExecuteADResultResponseRecorder( anomalyDetectionIndices, @@ -248,7 +252,7 @@ public void setup() throws Exception { adTaskCacheManager, 32 ); - runner.setExecuteADResultResponseRecorder(recorder); + adJobProcessor.setExecuteResultResponseRecorder(recorder); } @Rule @@ -293,7 +297,7 @@ public void testRunJobWithLockDuration() throws InterruptedException { @Test public void testRunAdJobWithNullLock() { LockModel lock = null; - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, never()).execute(any(), any(), any()); } @@ -301,7 +305,7 @@ public void testRunAdJobWithNullLock() { public void testRunAdJobWithLock() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); } @@ -311,7 +315,7 @@ public void testRunAdJobWithExecuteException() { doThrow(RuntimeException.class).when(client).execute(any(), any(), any()); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); assertTrue(testAppender.containsMessage("Failed to execute AD job")); } @@ -320,8 +324,8 @@ public void testRunAdJobWithExecuteException() { public void testRunAdJobWithEndRunExceptionNow() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -406,7 +410,8 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b Instant.now(), 60L, TestHelpers.randomUser(), - jobParameter.getCustomResultIndex() + jobParameter.getCustomResultIndex(), + AnalysisType.AD ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) ), Collections.emptyMap(), @@ -430,8 +435,8 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b return null; }).when(client).index(any(IndexRequest.class), any()); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -454,8 +459,8 @@ public void testRunAdJobWithEndRunExceptionNowAndGetJobException() { return null; }).when(client).get(any(GetRequest.class), any()); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -488,8 +493,8 @@ public void testRunAdJobWithEndRunExceptionNowAndFailToGetJob() { return null; }).when(client).get(any(), any()); - runner - .handleAdException( + adJobProcessor + .handleException( jobParameter, lockService, lock, @@ -519,10 +524,10 @@ public void 
testRunAdJobWithEndRunExceptionNotNowAndRetryUntilStop() throws Inte }).when(client).execute(any(), any(), any()); for (int i = 0; i < 3; i++) { - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(i + 1, testAppender.countMessage("EndRunException happened for")); } - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(1, testAppender.countMessage("JobRunner will stop AD job due to EndRunException retry exceeds upper limit")); } @@ -564,7 +569,8 @@ public Instant confirmInitializedSetup() { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -586,7 +592,7 @@ public void testFailtoFindDetector() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -615,7 +621,7 @@ public void testFailtoFindJob() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -638,7 +644,7 @@ public void testEmptyDetector() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -667,7 +673,7 @@ public void testEmptyJob() { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); @@ -756,7 +762,7 @@ public void 
testMarkResultIndexQueried() throws IOException { LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + adJobProcessor.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(nodeStateManager, times(1)).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); @@ -766,7 +772,7 @@ public void testMarkResultIndexQueried() throws IOException { ArgumentCaptor totalUpdates = ArgumentCaptor.forClass(Long.class); verify(adTaskManager, times(1)) .updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), totalUpdates.capture(), any(), any(), any()); - assertEquals(NUM_MIN_SAMPLES, totalUpdates.getValue().longValue()); + assertEquals(TimeSeriesSettings.NUM_MIN_SAMPLES, totalUpdates.getValue().longValue()); assertEquals(true, adTaskCacheManager.hasQueriedResultIndex(detector.getId())); } } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java index cb88bde96..3f5e14f26 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java @@ -42,13 +42,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.RCFPollingAction; import org.opensearch.ad.transport.RCFPollingResponse; import org.opensearch.cluster.ClusterName; @@ -65,8 +59,15 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.RemoteTransportException; @@ -114,7 +115,8 @@ private void setUpClientGet( nodeFilter, requiredSamples, transportService, - adTaskManager + adTaskManager, + taskProfileRunner ); doAnswer(invocation -> { @@ -208,7 +210,7 @@ public void testDetectorNotExist() throws IOException, InterruptedException { public void testDisabledJobIndexTemplate(JobStatus status) throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, status, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); - DetectorProfile 
expectedProfile = new DetectorProfile.Builder().state(DetectorState.DISABLED).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.DISABLED).build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); runner.profile(detector.getId(), ActionListener.wrap(response -> { @@ -229,10 +231,10 @@ public void testJobDisabled() throws IOException, InterruptedException { testDisabledJobIndexTemplate(JobStatus.DISABLED); } - public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorState expectedState) throws IOException, + public void testInitOrRunningStateTemplate(RCFPollingStatus status, ConfigState expectedState) throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, status, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(expectedState).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(expectedState).build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); runner.profile(detector.getId(), ActionListener.wrap(response -> { @@ -250,37 +252,37 @@ public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorStat } public void testResultNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, ConfigState.INIT); } public void testRemoteResultNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, ConfigState.INIT); } public void testCheckpointIndexNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, ConfigState.INIT); } public void testRemoteCheckpointIndexNotExist() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, ConfigState.INIT); } public void testResultEmpty() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, ConfigState.INIT); } public void testResultGreaterThanZero() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, DetectorState.RUNNING); + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, ConfigState.RUNNING); } @SuppressWarnings("unchecked") public void testErrorStateTemplate( RCFPollingStatus initStatus, ErrorResultStatus status, - DetectorState state, + ConfigState state, String error, JobStatus jobStatus, - Set profilesToCollect + Set profilesToCollect ) throws IOException, InterruptedException { ADTask adTask = TestHelpers.randomAdTask(); @@ -291,18 +293,18 @@ public void testErrorStateTemplate( Consumer> function = (Consumer>) args[2]; function.accept(Optional.of(adTask)); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); setUpClientExecuteRCFPollingAction(initStatus); 
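// ---------------------------------------------------------------------------------------------
// Editor's note (not part of the patch): the hunks around here migrate the AD-specific
// DetectorProfile/DetectorState/DetectorProfileName types to the shared ConfigProfile/
// ConfigState/ProfileName hierarchy, with DetectorProfile.Builder remaining the concrete AD
// builder. A minimal sketch of the expectation-building pattern the surrounding test uses
// (names are taken from the diff itself; treat this as illustrative, not authoritative):

static ConfigProfile expectedProfile(Set<ProfileName> profilesToCollect, ConfigState state, String error) {
    DetectorProfile.Builder builder = new DetectorProfile.Builder();
    if (profilesToCollect.contains(ProfileName.STATE)) {
        builder.state(state); // e.g. ConfigState.INIT, RUNNING, or DISABLED
    }
    if (profilesToCollect.contains(ProfileName.ERROR)) {
        builder.error(error); // e.g. the noFullShingleError / stoppedError messages above
    }
    return builder.build(); // compared via assertEquals against the runner.profile(...) response
}
// ---------------------------------------------------------------------------------------------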
setUpClientGet(DetectorStatus.EXIST, jobStatus, initStatus, status); DetectorProfile.Builder builder = new DetectorProfile.Builder(); - if (profilesToCollect.contains(DetectorProfileName.STATE)) { + if (profilesToCollect.contains(ProfileName.STATE)) { builder.state(state); } - if (profilesToCollect.contains(DetectorProfileName.ERROR)) { + if (profilesToCollect.contains(ProfileName.ERROR)) { builder.error(error); } - DetectorProfile expectedProfile = builder.build(); + ConfigProfile expectedProfile = builder.build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); runner.profile(detector.getId(), ActionListener.wrap(response -> { @@ -322,7 +324,7 @@ public void testErrorStateTemplate( public void testErrorStateTemplate( RCFPollingStatus initStatus, ErrorResultStatus status, - DetectorState state, + ConfigState state, String error, JobStatus jobStatus ) throws IOException, @@ -331,14 +333,14 @@ public void testErrorStateTemplate( } public void testRunningNoError() throws IOException, InterruptedException { - testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, DetectorState.RUNNING, null, JobStatus.ENABLED); + testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, ConfigState.RUNNING, null, JobStatus.ENABLED); } public void testRunningWithError() throws IOException, InterruptedException { testErrorStateTemplate( RCFPollingStatus.INIT_DONE, ErrorResultStatus.SHINGLE_ERROR, - DetectorState.RUNNING, + ConfigState.RUNNING, noFullShingleError, JobStatus.ENABLED ); @@ -348,7 +350,7 @@ public void testDisabledForStateError() throws IOException, InterruptedException testErrorStateTemplate( RCFPollingStatus.INITTING, ErrorResultStatus.STOPPED_ERROR, - DetectorState.DISABLED, + ConfigState.DISABLED, stoppedError, JobStatus.DISABLED ); @@ -358,7 +360,7 @@ public void testDisabledForStateInit() throws IOException, InterruptedException testErrorStateTemplate( RCFPollingStatus.INITTING, ErrorResultStatus.STOPPED_ERROR, - DetectorState.DISABLED, + ConfigState.DISABLED, stoppedError, JobStatus.DISABLED, stateInitProgress @@ -369,7 +371,7 @@ public void testInitWithError() throws IOException, InterruptedException { testErrorStateTemplate( RCFPollingStatus.EMPTY, ErrorResultStatus.SHINGLE_ERROR, - DetectorState.INIT, + ConfigState.INIT, noFullShingleError, JobStatus.ENABLED ); @@ -448,7 +450,7 @@ private void setUpClientExecuteProfileAction() { listener.onResponse(profileResponse); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); } @@ -541,7 +543,7 @@ public void testProfileModels() throws InterruptedException, IOException { public void testInitProgress() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); // 123 / 128 rounded to 96% InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); @@ -560,7 +562,7 @@ public void testInitProgress() throws IOException, InterruptedException { public void testInitProgressFailImmediately() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.NO_DOC, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); - DetectorProfile 
expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); // 123 / 128 rounded to 96% InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); @@ -579,8 +581,8 @@ public void testInitProgressFailImmediately() throws IOException, InterruptedExc public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder() - .state(DetectorState.INIT) + ConfigProfile expectedProfile = new DetectorProfile.Builder() + .state(ConfigState.INIT) .initProgress(new InitProgressProfile("0%", detectorIntervalMin * requiredSamples, requiredSamples)) .build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -601,8 +603,8 @@ public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { public void testInitNoIndex() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INDEX_NOT_FOUND, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder() - .state(DetectorState.INIT) + ConfigProfile expectedProfile = new DetectorProfile.Builder() + .state(ConfigState.INIT) .initProgress(new InitProgressProfile("0%", 0, requiredSamples)) .build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -624,7 +626,16 @@ public void testInitNoIndex() throws IOException, InterruptedException { public void testInvalidRequiredSamples() { expectThrows( IllegalArgumentException.class, - () -> new AnomalyDetectorProfileRunner(client, clientUtil, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) + () -> new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + 0, + transportService, + adTaskManager, + taskProfileRunner + ) ); } diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index 020281ac6..7ac70867e 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -306,7 +306,8 @@ public ToXContentObject[] getConfig(String detectorId, BasicHeader header, boole null, detector.getUser(), detector.getCustomResultIndex(), - detector.getImputationOption() + detector.getImputationOption(), + 0.001 ), detectorJob, historicalAdTask, @@ -639,7 +640,8 @@ protected AnomalyDetector cloneDetector(AnomalyDetector anomalyDetector, String anomalyDetector.getCategoryFields(), null, resultIndex, - anomalyDetector.getImputationOption() + anomalyDetector.getImputationOption(), + 0.001 ); return detector; } diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index af598c82d..0ec38eb9a 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -31,17 +31,9 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import 
org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.EntityProfile; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.EntityState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; -import org.opensearch.ad.transport.EntityProfileAction; -import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.transport.ADEntityProfileAction; import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -58,10 +50,18 @@ import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.EntityState; +import org.opensearch.timeseries.model.InitProgressProfile; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.transport.EntityProfileResponse; import org.opensearch.timeseries.util.SecurityClientUtil; public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { @@ -69,7 +69,7 @@ public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { private int detectorIntervalMin; private Client client; private SecurityClientUtil clientUtil; - private EntityProfileRunner runner; + private ADEntityProfileRunner runner; private Set state; private Set initNInfo; private Set model; @@ -139,7 +139,7 @@ public void setUp() throws Exception { }).when(nodeStateManager).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class)); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); - runner = new EntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); + runner = new ADEntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -220,7 +220,7 @@ private void setUpExecuteEntityProfileAction(InittedEverResultStatus initted) { listener.onResponse(profileResponseBuilder.build()); return null; - }).when(client).execute(any(EntityProfileAction.class), any(), any()); + }).when(client).execute(any(ADEntityProfileAction.class), any(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -308,7 +308,7 @@ public void testEmptyProfile() throws InterruptedException { assertTrue("Should not reach here", false); inProgressLatch.countDown(); }, exception -> { - assertTrue(exception.getMessage().contains(ADCommonMessages.EMPTY_PROFILES_COLLECT)); + assertTrue(exception.getMessage().contains(CommonMessages.EMPTY_PROFILES_COLLECT)); inProgressLatch.countDown(); })); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); @@ -400,7 +400,7 @@ public void testNotMultiEntityDetector() throws IOException, InterruptedExceptio assertTrue("Should not reach here", false); inProgressLatch.countDown(); }, exception -> { - assertTrue(exception.getMessage().contains(EntityProfileRunner.NOT_HC_DETECTOR_ERR_MSG)); + 
assertTrue(exception.getMessage().contains(ADEntityProfileRunner.NOT_HC_DETECTOR_ERR_MSG)); inProgressLatch.countDown(); })); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java index d40fa84f8..d2c27a68b 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java @@ -12,11 +12,6 @@ package org.opensearch.ad; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; import java.io.IOException; @@ -40,8 +35,6 @@ import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.transport.AnomalyDetectorJobAction; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -57,6 +50,9 @@ import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import com.google.common.collect.ImmutableList; @@ -180,14 +176,14 @@ public List searchADTasks(String detectorId, String parentTaskId, Boolea BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); if (isLatest != null) { - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, isLatest)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, isLatest)); } if (parentTaskId != null) { - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId)); } SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); + sourceBuilder.query(query).sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); searchRequest.source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); SearchResponse searchResponse = client().search(searchRequest).actionGet(); Iterator iterator = searchResponse.getHits().iterator(); @@ -224,29 +220,15 @@ public ADTask startHistoricalAnalysis(Instant startTime, Instant endTime) throws AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); - AnomalyDetectorJobResponse 
response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); + JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); return getADTask(response.getId()); } public ADTask startHistoricalAnalysis(String detectorId, Instant startTime, Instant endTime) throws IOException { DateRange dateRange = new DateRange(startTime, endTime); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); - AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); + JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); return getADTask(response.getId()); } } diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java index 35bf1a29f..588598400 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java @@ -28,6 +28,7 @@ import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Response; @@ -36,6 +37,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; @@ -251,9 +253,9 @@ protected List waitUntilTaskDone(String detectorId) throws InterruptedEx protected List waitUntilTaskReachState(String detectorId, Set targetStates) throws InterruptedException { List results = new ArrayList<>(); int i = 0; - ADTaskProfile adTaskProfile = null; + TaskProfile adTaskProfile = null; // Increase retryTimes if some task can't reach done state - while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getAdTask().getState())) && i < MAX_RETRY_TIMES) { + while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getTask().getState())) && i < MAX_RETRY_TIMES) { try { adTaskProfile = getADTaskProfile(detectorId); } catch (Exception e) { diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index 05a63e3df..bc20f29c1 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -18,7 +18,6 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -48,13 +47,9 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; import 
diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index 05a63e3df..bc20f29c1 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -18,7 +18,6 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; @@ -48,13 +47,9 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.DetectorProfile; -import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.AnomalyResultTests; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileNodeResponse; -import org.opensearch.ad.transport.ProfileResponse; import org.opensearch.ad.util.*; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -66,7 +61,12 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.ProfileName; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; @@ -79,7 +79,7 @@ public class MultiEntityProfileRunnerTests extends AbstractTimeSeriesTest { private int requiredSamples; private AnomalyDetector detector; private String detectorId; - private Set<DetectorProfileName> stateNError; + private Set<ProfileName> stateNError; private DetectorInternalState.Builder result; private String node1; private String nodeName1; @@ -97,6 +97,7 @@ public class MultiEntityProfileRunnerTests extends AbstractTimeSeriesTest { private Job job; private TransportService transportService; private ADTaskManager adTaskManager; + private ADTaskProfileRunner taskProfileRunner; enum InittedEverResultStatus { INITTED, @@ -119,7 +120,7 @@ public static void tearDownAfterClass() { public void setUp() throws Exception { super.setUp(); client = mock(Client.class); - Clock clock = mock(Clock.class); + taskProfileRunner = mock(ADTaskProfileRunner.class); NodeStateManager nodeStateManager = mock(NodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); nodeFilter = mock(DiscoveryNodeFilterer.class); @@ -137,7 +138,7 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); runner = new AnomalyDetectorProfileRunner( client, clientUtil, @@ -145,7 +146,8 @@ public void setUp() throws Exception { nodeFilter, requiredSamples, transportService, - adTaskManager + adTaskManager, + taskProfileRunner ); doAnswer(invocation -> { @@ -165,9 +167,9 @@ public void setUp() throws Exception { return null; }).when(client).get(any(), any()); - stateNError = new HashSet<DetectorProfileName>(); - stateNError.add(DetectorProfileName.ERROR); - stateNError.add(DetectorProfileName.STATE); + stateNError = new HashSet<ProfileName>(); + stateNError.add(ProfileName.ERROR); + stateNError.add(ProfileName.STATE); } @SuppressWarnings("unchecked") @@ -248,7 +250,7 @@ private void setUpClientExecuteProfileAction(InittedEverResultStatus initted) { listener.onResponse(profileResponse); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); } @@ -285,7 +287,7 @@ public void testInit() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); @@ -302,7 +304,7 @@ public void testRunning() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.RUNNING).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); @@ -323,7 +325,7 @@ public void testResultIndexFinalTruth() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + ConfigProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.RUNNING).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown();
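Note: the profile transport types split in this refactor: the action keeps an AD prefix (ADProfileAction, was ProfileAction), while ProfileNodeResponse and ProfileResponse move to org.opensearch.timeseries.transport. A sketch of the corresponding Mockito stub, condensed from the hunk above (the helper name is hypothetical; the listener index follows the standard execute(action, request, listener) signature):

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;

import org.opensearch.ad.transport.ADProfileAction;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.transport.ProfileResponse;

private void stubProfileAction(Client client, ProfileResponse profileResponse) {
    doAnswer(invocation -> {
        // third argument of client.execute(action, request, listener)
        ActionListener<ProfileResponse> listener = invocation.getArgument(2);
        listener.onResponse(profileResponse); // canned per-node profile aggregate
        return null;
    }).when(client).execute(any(ADProfileAction.class), any(), any());
}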
diff --git a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java index 0f513b502..0c1a6812d 100644 --- a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java +++ b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java @@ -42,13 +42,13 @@ import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.rest.ADRestTestUtils; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.RestHandlerUtils; @@ -435,7 +435,7 @@ private List startAnomalyDetector(Response response, boolean historicalD Map<String, Object> responseMap = entityAsMap(response); String detectorId = (String) responseMap.get("_id"); int version = (int) responseMap.get("_version"); - assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, detectorId); + assertNotEquals("response is missing Id", Config.NO_ID, detectorId); assertTrue("incorrect version", version > 0); Response startDetectorResponse = TestHelpers
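Note on the cache tests that follow: the AD-only EntityModel wrapper is gone, and the shared ModelState is parameterized directly on the RCF model type. A sketch of the new construction, with argument roles inferred from the hunks below (model, model id, config id, model type, clock, priority, last checkpoint time, entity, pending samples; reading the seventh argument as the last checkpoint time is an assumption based on the getLastCheckpointTime() accessor used later in this diff):

import java.time.Clock;
import java.util.ArrayDeque;
import java.util.Optional;

import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.model.Entity;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

Entity entity = Entity.createSingleAttributeEntity("attributeName1", "attributeVal1");
ModelState<ThresholdedRandomCutForest> state = new ModelState<>(
    null,                                  // no trained model restored yet
    "model-id-1",                          // hypothetical model id
    "detector-id-1",                       // hypothetical config (detector) id
    ModelManager.ModelType.TRCF.getName(), // replaces ModelType.ENTITY
    Clock.systemUTC(),
    0,                                     // initial priority
    null,                                  // last checkpoint time (assumption)
    Optional.of(entity),
    new ArrayDeque<>()                     // pending samples
);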
diff --git a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java index 2c990682a..d741ee20e 100644 --- a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java +++ b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java @@ -19,34 +19,36 @@ import java.time.Instant; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Optional; import java.util.Random; import org.junit.Before; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + public class AbstractCacheTest extends AbstractTimeSeriesTest { protected String modelId1, modelId2, modelId3, modelId4; protected Entity entity1, entity2, entity3, entity4; - protected ModelState<EntityModel> modelState1, modelState2, modelState3, modelState4; + protected ModelState<ThresholdedRandomCutForest> modelState1, modelState2, modelState3, modelState4; protected String detectorId; protected AnomalyDetector detector; protected Clock clock; protected Duration detectorDuration; protected float initialPriority; - protected CacheBuffer cacheBuffer; + protected ADCacheBuffer cacheBuffer; protected long memoryPerEntity; protected MemoryTracker memoryTracker; - protected CheckpointWriteWorker checkpointWriteQueue; - protected CheckpointMaintainWorker checkpointMaintainQueue; + protected ADCheckpointWriteWorker checkpointWriteQueue; + protected ADCheckpointMaintainWorker checkpointMaintainQueue; protected Random random; protected int shingleSize; @@ -85,58 +87,70 @@ public void setUp() throws Exception { memoryPerEntity = 81920; memoryTracker = mock(MemoryTracker.class); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); - checkpointMaintainQueue = mock(CheckpointMaintainWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); + checkpointMaintainQueue = mock(ADCheckpointMaintainWorker.class); - cacheBuffer = new CacheBuffer( - 1, + cacheBuffer = new ADCacheBuffer( 1, - memoryPerEntity, - memoryTracker, clock, + memoryTracker, + 1, TimeSeriesSettings.HOURLY_MAINTENANCE, - detectorId, + memoryPerEntity, checkpointWriteQueue, checkpointMaintainQueue, + detectorId, Duration.ofHours(12).toHoursPart() ); initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); - modelState1 = new ModelState<>( - new EntityModel(entity1, new ArrayDeque<>(), null), + modelState1 = new ModelState<ThresholdedRandomCutForest>( + null, modelId1, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity1), + new ArrayDeque<>() ); - modelState2 = new ModelState<>( - new EntityModel(entity2, new ArrayDeque<>(), null), + modelState2 = new ModelState<ThresholdedRandomCutForest>( + null, modelId2, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity2), + new ArrayDeque<>() ); - modelState3 = new ModelState<>( - new EntityModel(entity3, new ArrayDeque<>(), null), + modelState3 = new ModelState<ThresholdedRandomCutForest>( + null, modelId3, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity3), + new ArrayDeque<>() ); - modelState4 = new ModelState<>( - new EntityModel(entity4, new ArrayDeque<>(), null), + modelState4 = new ModelState<ThresholdedRandomCutForest>( + null, modelId4, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity4), + new ArrayDeque<>() ); } }
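Note: the next file exercises the cache buffer surface that was renamed as it moved into the shared layer. A short sketch of the renamed call sites, shown against the fixtures AbstractCacheTest sets up (cacheBuffer, modelId1, modelId2; toSave is the captured List of CheckpointMaintainRequest from the maintenance test):

// was: cacheBuffer.getModel(modelId2).isPresent()
assertTrue(cacheBuffer.getModelState(modelId2) != null);
// was: cacheBuffer.canReplaceWithinDetector(100)
assertTrue(cacheBuffer.canReplaceWithinConfig(100));
// was: cacheBuffer.getAllModels().size()
assertEquals(3, cacheBuffer.getAllModelStates().size());
// was: CheckpointMaintainRequest.getEntityModelId()
assertEquals(modelId1, toSave.get(0).getModelId());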
diff --git a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java index 265560ab5..0eb4ac947 100644 --- a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java +++ b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java @@ -22,8 +22,8 @@ import java.util.Optional; import org.mockito.ArgumentCaptor; -import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -69,7 +69,7 @@ public void testRemovalCandidate2() throws InterruptedException { cacheBuffer.put(modelId2, modelState2); cacheBuffer.put(modelId2, modelState2); cacheBuffer.put(modelId4, modelState4); - assertTrue(cacheBuffer.getModel(modelId2).isPresent()); + assertTrue(cacheBuffer.getModelState(modelId2) != null); ArgumentCaptor<Long> memoryReleased = ArgumentCaptor.forClass(Long.class); ArgumentCaptor<Boolean> reserved = ArgumentCaptor.forClass(Boolean.class); @@ -93,10 +93,10 @@ public void testCanRemove() { String modelId2 = "2"; String modelId3 = "3"; assertTrue(cacheBuffer.dedicatedCacheAvailable()); - assertTrue(!cacheBuffer.canReplaceWithinDetector(100)); + assertTrue(!cacheBuffer.canReplaceWithinConfig(100)); cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); - assertTrue(cacheBuffer.canReplaceWithinDetector(100)); + assertTrue(cacheBuffer.canReplaceWithinConfig(100)); assertTrue(!cacheBuffer.dedicatedCacheAvailable()); assertTrue(!cacheBuffer.canRemove()); cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); @@ -117,7 +117,7 @@ public void testMaintenance() { cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); cacheBuffer.maintenance(); assertEquals(3, cacheBuffer.getActiveEntities()); - assertEquals(3, cacheBuffer.getAllModels().size()); + assertEquals(3, cacheBuffer.getAllModelStates().size()); // the year of 2122, 100 years later to simulate we are gonna remove all cached entries when(clock.instant()).thenReturn(Instant.ofEpochSecond(4814540761L)); cacheBuffer.maintenance(); @@ -167,7 +167,7 @@ public void testMaintainByHourSaveOne() { verify(checkpointMaintainQueue, times(1)).putAll(savedStates.capture()); List<CheckpointMaintainRequest> toSave = savedStates.getValue(); assertEquals(1, toSave.size()); - assertEquals(modelId1, toSave.get(0).getEntityModelId()); + assertEquals(modelId1, toSave.get(0).getModelId()); } /** diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index 4154687cf..9afefcb37 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -44,11 +45,8 @@ import org.apache.logging.log4j.Logger; import org.junit.Before; import org.mockito.ArgumentCaptor; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelManager.ModelType; -import
org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -65,15 +63,20 @@ import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + public class PriorityCacheTests extends AbstractCacheTest { private static final Logger LOG = LogManager.getLogger(PriorityCacheTests.class); - EntityCache entityCache; - CheckpointDao checkpoint; - ModelManager modelManager; + ADPriorityCache entityCache; + ADCheckpointDao checkpoint; + ADModelManager modelManager; ClusterService clusterService; Settings settings; @@ -87,9 +90,9 @@ public class PriorityCacheTests extends AbstractCacheTest { public void setUp() throws Exception { super.setUp(); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); clusterService = mock(ClusterService.class); ClusterSettings settings = new ClusterSettings( @@ -115,7 +118,7 @@ public void setUp() throws Exception { threadPool = mock(ThreadPool.class); setUpADThreadPool(threadPool); - EntityCache cache = new PriorityCache( + ADPriorityCache cache = new ADPriorityCache( checkpoint, dedicatedCacheSize, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, @@ -126,14 +129,14 @@ public void setUp() throws Exception { clusterService, TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, Settings.EMPTY, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + checkpointWriteQueue, + checkpointMaintainQueue ); - CacheProvider cacheProvider = new CacheProvider(); + ADCacheProvider cacheProvider = new ADCacheProvider(); cacheProvider.set(cache); entityCache = cacheProvider.get(); @@ -171,7 +174,7 @@ public void testCacheHit() { memoryTracker = spy(new MemoryTracker(jvmService, modelMaxPercen, clusterService, mock(CircuitBreakerService.class))); - EntityCache cache = new PriorityCache( + ADPriorityCache cache = new ADPriorityCache( checkpoint, dedicatedCacheSize, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, @@ -182,14 +185,14 @@ public void testCacheHit() { clusterService, TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, Settings.EMPTY, - AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + checkpointWriteQueue, + checkpointMaintainQueue ); - CacheProvider cacheProvider = new CacheProvider(); + ADCacheProvider cacheProvider = new ADCacheProvider(); cacheProvider.set(cache); entityCache = cacheProvider.get(); @@ -200,13 +203,14 @@ public void testCacheHit() { entityCache.hostIfPossible(detector, modelState1); assertEquals(1, entityCache.getTotalActiveEntities()); assertEquals(1, 
entityCache.getAllModels().size()); - ModelState hitState = entityCache.get(modelState1.getModelId(), detector); - assertEquals(detectorId, hitState.getId()); - EntityModel model = hitState.getModel(); - assertEquals(false, model.getTrcf().isPresent()); - assertTrue(model.getSamples().isEmpty()); - modelState1.getModel().addSample(point); - assertTrue(Arrays.equals(point, model.getSamples().peek())); + ModelState hitState = entityCache.get(modelState1.getModelId(), detector); + assertEquals(detectorId, hitState.getConfigId()); + Optional model = hitState.getModel(); + assertTrue(model.isEmpty()); + assertTrue(hitState.getSamples().isEmpty()); + Sample sample = new Sample(point, Instant.now(), Instant.now()); + modelState1.addSample(sample); + assertTrue(Arrays.equals(point, hitState.getSamples().peek().getValueList())); ArgumentCaptor memoryConsumed = ArgumentCaptor.forClass(Long.class); ArgumentCaptor reserved = ArgumentCaptor.forClass(Boolean.class); @@ -260,12 +264,15 @@ public void testSharedCache() { entityCache.get(modelId3, detector2); } modelState3 = new ModelState<>( - new EntityModel(entity3, new ArrayDeque<>(), null), + null, modelId3, detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity3), + new ArrayDeque<>() ); entityCache.hostIfPossible(detector2, modelState3); @@ -276,12 +283,15 @@ public void testSharedCache() { entityCache.get(modelId4, detector2); } modelState4 = new ModelState<>( - new EntityModel(entity4, new ArrayDeque<>(), null), + null, modelId4, detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity4), + new ArrayDeque<>() ); entityCache.hostIfPossible(detector2, modelState4); assertEquals(2, entityCache.getActiveEntities(detectorId2)); @@ -303,7 +313,7 @@ public void testReplace() { entityCache.hostIfPossible(detector, modelState1); assertEquals(1, entityCache.getActiveEntities(detectorId)); when(memoryTracker.canAllocate(anyLong())).thenReturn(false); - ModelState state = null; + ModelState state = null; for (int i = 0; i < 4; i++) { entityCache.get(modelId2, detector); @@ -366,7 +376,7 @@ public void testClear() { assertEquals(2, entityCache.getTotalActiveEntities()); assertTrue(entityCache.isActive(detectorId, modelId1)); assertEquals(0, entityCache.getTotalUpdates(detectorId)); - modelState1.getModel().addSample(point); + modelState1.addSample(new Sample(point, Instant.now(), Instant.now())); assertEquals(1, entityCache.getTotalUpdates(detectorId)); assertEquals(1, entityCache.getTotalUpdates(detectorId, modelId1)); entityCache.clear(detectorId); @@ -538,21 +548,27 @@ public void testSelectToReplaceInCache() { private void replaceInOtherCacheSetUp() { Entity entity5 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal5"); Entity entity6 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal6"); - ModelState modelState5 = new ModelState<>( - new EntityModel(entity5, new ArrayDeque<>(), null), + ModelState modelState5 = new ModelState<>( + null, entity5.getModelId(detectorId2).get(), detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity5), + new ArrayDeque<>() ); - ModelState modelState6 = new ModelState<>( - new EntityModel(entity6, new ArrayDeque<>(), null), + ModelState modelState6 = new ModelState<>( + null, entity6.getModelId(detectorId2).get(), detectorId2, - ModelType.ENTITY.getName(), + 
ModelManager.ModelType.TRCF.getName(), clock, - 0 + 0, + null, + Optional.of(entity6), + new ArrayDeque<>() ); for (int i = 0; i < 3; i++) { @@ -660,7 +676,7 @@ public void testLongDetectorInterval() { String modelId = entity1.getModelId(detectorId).get(); // record last access time 1000 assertTrue(null == entityCache.get(modelId, detector)); - assertEquals(-1, entityCache.getLastActiveMs(detectorId, modelId)); + assertEquals(-1, entityCache.getLastActiveTime(detectorId, modelId)); // 2 hour = 7200 seconds have passed long currentTimeEpoch = 8200; when(clock.instant()).thenReturn(Instant.ofEpochSecond(currentTimeEpoch)); @@ -669,7 +685,7 @@ public void testLongDetectorInterval() { // door keeper still has the record and won't blocks entity state being created entityCache.get(modelId, detector); // * 1000 to convert to milliseconds - assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveMs(detectorId, modelId)); + assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveTime(detectorId, modelId)); } finally { ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, false); } @@ -725,7 +741,7 @@ public void testRemoveEntityModel() { assertTrue(null != entityCache.get(entity2.getModelId(detectorId).get(), detector)); - entityCache.removeEntityModel(detectorId, entity2.getModelId(detectorId).get()); + entityCache.removeModel(detectorId, entity2.getModelId(detectorId).get()); assertTrue(null == entityCache.get(entity2.getModelId(detectorId).get(), detector)); diff --git a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java index 4e721d68e..09cc23bd6 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java @@ -21,6 +21,7 @@ import org.junit.Before; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.caching.PriorityTracker; public class PriorityTrackerTests extends OpenSearchTestCase { Clock clock; diff --git a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java index 88546e5ce..ba6ee0374 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java @@ -27,7 +27,6 @@ import org.junit.Before; import org.junit.BeforeClass; import org.opensearch.Version; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -38,6 +37,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.gateway.GatewayService; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.ClusterEventListener; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.constant.CommonName; public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { private final String clusterManagerNodeId = "clusterManagerNode"; @@ -45,7 +47,7 @@ public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { private final String clusterName = "multi-node-cluster"; private ClusterService clusterService; - private ADClusterEventListener listener; + private ClusterEventListener listener; private HashRing hashRing; private ClusterState oldClusterState; private 
ClusterState newClusterState; @@ -66,7 +68,7 @@ public static void tearDownAfterClass() { @Before public void setUp() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(ADClusterEventListener.class); + super.setUpLog4jForJUnit(ClusterEventListener.class); clusterService = createClusterService(threadPool); hashRing = mock(HashRing.class); @@ -98,7 +100,7 @@ public void setUp() throws Exception { ) .build(); - listener = new ADClusterEventListener(clusterService, hashRing); + listener = new ClusterEventListener(clusterService, hashRing); } @Override @@ -114,12 +116,12 @@ public void tearDown() throws Exception { public void testUnchangedClusterState() { listener.clusterChanged(new ClusterChangedEvent("foo", oldClusterState, oldClusterState)); - assertTrue(!testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(!testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); } public void testIsWarmNode() { HashMap attributesForNode1 = new HashMap<>(); - attributesForNode1.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + attributesForNode1.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE); dataNode1 = new DiscoveryNode(dataNode1Id, buildNewFakeTransportAddress(), attributesForNode1, BUILT_IN_ROLES, Version.CURRENT); ClusterState warmNodeClusterState = ClusterState @@ -134,7 +136,7 @@ public void testIsWarmNode() { .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) .build(); listener.clusterChanged(new ClusterChangedEvent("foo", warmNodeClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NOT_RECOVERED_MSG)); } public void testNotRecovered() { @@ -150,7 +152,7 @@ public void testNotRecovered() { .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) .build(); listener.clusterChanged(new ClusterChangedEvent("foo", blockedClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NOT_RECOVERED_MSG)); } class ListenerRunnable implements Runnable { @@ -170,7 +172,7 @@ public void testInProgress() { }).when(hashRing).buildCircles(any(), any()); new Thread(new ListenerRunnable()).start(); listener.clusterChanged(new ClusterChangedEvent("bar", newClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.IN_PROGRESS_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.IN_PROGRESS_MSG)); } public void testNodeAdded() { @@ -182,10 +184,10 @@ public void testNodeAdded() { doAnswer(invocation -> Optional.of(clusterManagerNode)) .when(hashRing) - .getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)); + .getOwningNodeWithSameLocalVersionForRealtime(any(String.class)); listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); assertTrue(testAppender.containsMessage("node removed: false, node added: true")); } @@ -203,7 +205,7 @@ public void testNodeRemoved() { .build(); listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, twoDataNodeClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + 
assertTrue(testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); assertTrue(testAppender.containsMessage("node removed: true, node added: true")); } } diff --git a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java index 66fbd3e78..928a0ddaf 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java @@ -47,6 +47,7 @@ import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.ADDataMigrator; import org.opensearch.timeseries.constant.CommonName; public class ADDataMigratorTests extends ADUnitTestCase { diff --git a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java index aa5fcc55b..79f1cd26d 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java @@ -13,22 +13,23 @@ import org.opensearch.Version; import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.timeseries.cluster.VersionUtil; public class ADVersionUtilTests extends ADUnitTestCase { public void testParseVersionFromString() { - Version version = ADVersionUtil.fromString("2.1.0.0"); + Version version = VersionUtil.fromString("2.1.0.0"); assertEquals(Version.V_2_1_0, version); - version = ADVersionUtil.fromString("2.1.0"); + version = VersionUtil.fromString("2.1.0"); assertEquals(Version.V_2_1_0, version); } public void testParseVersionFromStringWithNull() { - expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString(null)); + expectThrows(IllegalArgumentException.class, () -> VersionUtil.fromString(null)); } public void testParseVersionFromStringWithWrongFormat() { - expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString("1.1")); + expectThrows(IllegalArgumentException.class, () -> VersionUtil.fromString("1.1")); } } diff --git a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java index 9c2e79236..33ab3958e 100644 --- a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java @@ -24,11 +24,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Locale; import org.junit.Before; -import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -36,9 +35,14 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.settings.ForecastSettings; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.ClusterManagerEventListener; +import org.opensearch.timeseries.cluster.HourlyCron; +import 
org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -69,13 +73,13 @@ public void setUp() throws Exception { checkpointIndexRetentionCancellable = mock(Cancellable.class); when(threadPool.scheduleWithFixedDelay(any(HourlyCron.class), any(TimeValue.class), any(String.class))) .thenReturn(hourlyCancellable); - when(threadPool.scheduleWithFixedDelay(any(ModelCheckpointIndexRetention.class), any(TimeValue.class), any(String.class))) + when(threadPool.scheduleWithFixedDelay(any(BaseModelCheckpointIndexRetention.class), any(TimeValue.class), any(String.class))) .thenReturn(checkpointIndexRetentionCancellable); client = mock(Client.class); clock = mock(Clock.class); clientUtil = mock(ClientUtil.class); HashMap ignoredAttributes = new HashMap(); - ignoredAttributes.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + ignoredAttributes.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE); nodeFilter = new DiscoveryNodeFilterer(clusterService); clusterManagerService = new ClusterManagerEventListener( @@ -86,6 +90,7 @@ public void setUp() throws Exception { clientUtil, nodeFilter, AnomalyDetectorSettings.AD_CHECKPOINT_TTL, + ForecastSettings.FORECAST_CHECKPOINT_TTL, Settings.EMPTY ); } @@ -95,7 +100,10 @@ public void testOnOffClusterManager() { assertThat(hourlyCancellable, is(notNullValue())); assertThat(checkpointIndexRetentionCancellable, is(notNullValue())); assertTrue(!clusterManagerService.getHourlyCron().isCancelled()); - assertTrue(!clusterManagerService.getCheckpointIndexRetentionCron().isCancelled()); + List checkpointIndexRetention = clusterManagerService.getCheckpointIndexRetentionCron(); + for (Cancellable cancellable : checkpointIndexRetention) { + assertTrue(!cancellable.isCancelled()); + } clusterManagerService.offClusterManager(); assertThat(clusterManagerService.getCheckpointIndexRetentionCron(), is(nullValue())); assertThat(clusterManagerService.getHourlyCron(), is(nullValue())); diff --git a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java index a57a4c649..1ad8834e4 100644 --- a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java @@ -28,6 +28,7 @@ import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.DailyCron; import org.opensearch.timeseries.util.ClientUtil; public class DailyCronTests extends AbstractTimeSeriesTest { diff --git a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java index 69bb38a57..12826aea1 100644 --- a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java @@ -36,8 +36,7 @@ import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ 
-50,6 +49,8 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.plugins.PluginInfo; +import org.opensearch.timeseries.cluster.ADDataMigrator; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -74,7 +75,7 @@ public class HashRingTests extends ADUnitTestCase { private DiscoveryNode localNode; private DiscoveryNode newNode; private DiscoveryNode warmNode; - private ModelManager modelManager; + private ADModelManager modelManager; @Override @Before @@ -86,7 +87,7 @@ public void setUp() throws Exception { newNodeId = "newNode"; newNode = createNode(newNodeId, "127.0.0.2", 9201, emptyMap()); warmNodeId = "warmNode"; - warmNode = createNode(warmNodeId, "127.0.0.3", 9202, ImmutableMap.of(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE)); + warmNode = createNode(warmNodeId, "127.0.0.3", 9202, ImmutableMap.of(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE)); settings = Settings.builder().put(AD_COOLDOWN_MINUTES.getKey(), TimeValue.timeValueSeconds(5)).build(); ClusterSettings clusterSettings = clusterSetting(settings, AD_COOLDOWN_MINUTES); @@ -107,7 +108,7 @@ public void setUp() throws Exception { when(adminClient.cluster()).thenReturn(clusterAdminClient); String modelId = "123_model_threshold"; - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); doAnswer(invocation -> { Set res = new HashSet<>(); res.add(modelId); @@ -121,7 +122,7 @@ public void testGetOwningNodeWithEmptyResult() throws UnknownHostException { DiscoveryNode node1 = createNode(Integer.toString(1), "127.0.0.4", 9204, emptyMap()); doReturn(node1).when(clusterService).localNode(); - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime("http-latency-rcf-1"); assertFalse(node.isPresent()); } @@ -130,10 +131,10 @@ public void testGetOwningNode() throws UnknownHostException { // Add first node, hashRing.buildCircles(delta, ActionListener.wrap(r -> { - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime("http-latency-rcf-1"); assertTrue(node.isPresent()); assertTrue(asList(newNodeId, localNodeId).contains(node.get().getId())); - DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalAdVersion(); + DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalVersion(); Set nodesWithSameLocalAdVersionIds = new HashSet<>(); for (DiscoveryNode n : nodesWithSameLocalAdVersion) { nodesWithSameLocalAdVersionIds.add(n.getId()); @@ -143,10 +144,10 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 2, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); // Circles for realtime AD will change as it's eligible to build for when its empty - assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); assertFalse("Build hash ring failed", true); @@ -162,10 +163,10 @@ 
public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 3, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); // Circles for realtime AD will not change as it's eligible to rebuild - assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); @@ -183,9 +184,9 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 4, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); - assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); assertFalse("Failed to build hash ring", true); @@ -194,7 +195,7 @@ public void testGetOwningNode() throws UnknownHostException { public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { setupNodeDelta(); - hashRing.getAllEligibleDataNodesWithKnownAdVersion(nodes -> { + hashRing.getAllEligibleDataNodesWithKnownVersion(nodes -> { assertEquals("Wrong hash ring size for historical analysis", 2, nodes.length); Optional node = hashRing.getNodeByAddress(newNode.getAddress()); assertTrue(node.isPresent()); @@ -205,7 +206,7 @@ public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { public void testBuildAndGetOwningNodeWithSameLocalAdVersion() { setupNodeDelta(); hashRing - .buildAndGetOwningNodeWithSameLocalAdVersion( + .buildAndGetOwningNodeWithSameLocalVersion( "testModelId", node -> { assertTrue(node.isPresent()); }, ActionListener.wrap(r -> {}, e -> { diff --git a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java index 2806138d9..831b546da 100644 --- a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java @@ -27,10 +27,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.transport.CronAction; -import org.opensearch.ad.transport.CronNodeResponse; -import org.opensearch.ad.transport.CronResponse; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -39,6 +36,10 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.HourlyCron; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.CronNodeResponse; +import org.opensearch.timeseries.transport.CronResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import test.org.opensearch.ad.util.ClusterCreation; @@ -59,7 +60,7 @@ public void templateHourlyCron(HourlyCronTestExecutionMode mode) { ClusterState state = 
ClusterCreation.state(1); when(clusterService.state()).thenReturn(state); HashMap ignoredAttributes = new HashMap(); - ignoredAttributes.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + ignoredAttributes.put(CommonName.BOX_TYPE_KEY, CommonName.WARM_BOX_TYPE); DiscoveryNodeFilterer nodeFilter = new DiscoveryNodeFilterer(clusterService); Client client = mock(Client.class); diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java index 0748fe122..399da125e 100644 --- a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java @@ -37,6 +37,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.store.StoreStats; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; import org.opensearch.timeseries.util.ClientUtil; public class IndexCleanupTests extends AbstractTimeSeriesTest { diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java b/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java index b95757925..eca572199 100644 --- a/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java @@ -27,8 +27,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.core.action.ActionListener; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.diskcleanup.BaseModelCheckpointIndexRetention; +import org.opensearch.timeseries.cluster.diskcleanup.IndexCleanup; public class ModelCheckpointIndexRetentionTests extends AbstractTimeSeriesTest { @@ -39,7 +42,7 @@ public class ModelCheckpointIndexRetentionTests extends AbstractTimeSeriesTest { @Mock IndexCleanup indexCleanup; - ModelCheckpointIndexRetention modelCheckpointIndexRetention; + BaseModelCheckpointIndexRetention modelCheckpointIndexRetention; @SuppressWarnings("unchecked") @Before @@ -47,7 +50,12 @@ public void setUp() throws Exception { super.setUp(); super.setUpLog4jForJUnit(IndexCleanup.class); MockitoAnnotations.initMocks(this); - modelCheckpointIndexRetention = new ModelCheckpointIndexRetention(defaultCheckpointTtl, clock, indexCleanup); + modelCheckpointIndexRetention = new BaseModelCheckpointIndexRetention( + defaultCheckpointTtl, + clock, + indexCleanup, + ADIndex.CHECKPOINT.getIndexName() + ); doAnswer(invocation -> { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[2]; diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java index 4330118b6..b74f1ea58 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java @@ -11,9 +11,9 @@ package org.opensearch.ad.e2e; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE; import static org.opensearch.timeseries.TestHelpers.toHttpEntity; -import static 
org.opensearch.timeseries.settings.TimeSeriesSettings.BACKOFF_MINUTES; -import static org.opensearch.timeseries.settings.TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.File; import java.io.FileReader; @@ -61,8 +61,8 @@ protected void disableResourceNotFoundFaultTolerence() throws IOException { settingCommand.startObject(); settingCommand.startObject("persistent"); - settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); - settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.field(AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(AD_BACKOFF_MINUTES.getKey(), 0); settingCommand.endObject(); settingCommand.endObject(); Request request = new Request("PUT", "/_cluster/settings"); diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java index 8edab0d15..e856cd1cd 100644 --- a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java @@ -27,12 +27,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.core.Logger; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; import com.google.common.collect.ImmutableMap; import com.google.gson.JsonElement; @@ -117,10 +117,7 @@ public void testValidationIntervalRecommendation() throws Exception { @SuppressWarnings("unchecked") Map> messageMap = (Map>) XContentMapValues .extractValue("model", responseMap); - assertEquals( - ADCommonMessages.DETECTOR_INTERVAL_REC + recDetectorIntervalMinutes, - messageMap.get("detection_interval").get("message") - ); + assertEquals(CommonMessages.INTERVAL_REC + recDetectorIntervalMinutes, messageMap.get("detection_interval").get("message")); } public void testValidationWindowDelayRecommendation() throws Exception { @@ -158,7 +155,7 @@ public void testValidationWindowDelayRecommendation() throws Exception { Map> messageMap = (Map>) XContentMapValues .extractValue("model", responseMap); assertEquals( - String.format(Locale.ROOT, ADCommonMessages.WINDOW_DELAY_REC, expectedWindowDelayMinutes, expectedWindowDelayMinutes), + String.format(Locale.ROOT, CommonMessages.WINDOW_DELAY_REC, expectedWindowDelayMinutes, expectedWindowDelayMinutes), messageMap.get("window_delay").get("message") ); } diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index b78647c11..01549b02d 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -61,7 +61,10 @@ import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; @@ -142,8 
+145,6 @@ public void setup() { searchFeatureDao, imputer, clock, - maxTrainSamples, - maxSampleStride, trainSampleTimeRangeInHours, minTrainSamples, maxMissingPointsRate, @@ -203,7 +204,7 @@ public void getColdStartData_returnExpectedToListener( ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.ofNullable(latestTime)); return null; - }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); if (latestTime != null) { doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(3); @@ -220,8 +221,6 @@ public void getColdStartData_returnExpectedToListener( searchFeatureDao, imputer, clock, - maxTrainSamples, - maxSampleStride, trainSampleTimeRangeInHours, minTrainSamples, 0.5, /*maxMissingPointsRate*/ @@ -248,7 +247,7 @@ public void getColdStartData_throwToListener_whenSearchFail() { ActionListener> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); ActionListener> listener = mock(ActionListener.class); featureManager.getColdStartData(detector, listener); @@ -263,7 +262,7 @@ public void getColdStartData_throwToListener_onQueryCreationError() throws Excep ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.ofNullable(0L)); return null; - }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + }).when(searchFeatureDao).getLatestDataTime(eq(detector), eq(Optional.empty()), eq(AnalysisType.AD), any(ActionListener.class)); doThrow(IOException.class) .when(searchFeatureDao) .getFeatureSamplesForPeriods(eq(detector), any(), eq(AnalysisType.AD), any(ActionListener.class)); diff --git a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java index 7a6b3b8e1..f7716e81b 100644 --- a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.junit.runner.RunWith; +import org.opensearch.timeseries.feature.Features; import junitparams.JUnitParamsRunner; import junitparams.Parameters; diff --git a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java index 53bea9015..ebeded321 100644 --- a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java +++ b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java @@ -142,7 +142,7 @@ private Map createMapping() { entity_nested_mapping.put("name", Collections.singletonMap("type", "keyword")); entity_nested_mapping.put("value", Collections.singletonMap("type", "keyword")); entity_mapping.put(CommonName.PROPERTIES, entity_nested_mapping); - mappings.put(CommonName.ENTITY_FIELD, entity_mapping); + mappings.put(CommonName.ENTITY_KEY, entity_mapping); Map error_mapping = new HashMap<>(); error_mapping.put("type", "text"); @@ -188,7 +188,7 @@ private Map createMapping() { attribution_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); mappings.put(AnomalyResult.RELEVANT_ATTRIBUTION_FIELD, attribution_mapping); - mappings.put(CommonName.SCHEMA_VERSION_FIELD, 
Collections.singletonMap("type", "integer")); + mappings.put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, Collections.singletonMap("type", "integer")); mappings.put(CommonName.TASK_ID_FIELD, Collections.singletonMap("type", "keyword")); diff --git a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java index 10b00426a..0a837603d 100644 --- a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java +++ b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java @@ -55,7 +55,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.timeseries.AbstractTimeSeriesTest; -import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -168,7 +167,7 @@ public void testUpdateMapping() throws IOException { put(ADIndexManagement.META, new HashMap() { { // version 1 will cause update - put(CommonName.SCHEMA_VERSION_FIELD, 1); + put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, 1); } }); } diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java index 1a86e45d4..8c67a57e7 100644 --- a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -21,6 +21,8 @@ import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -29,9 +31,8 @@ import org.opensearch.Version; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -52,12 +53,16 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ClientUtil; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; public class AbstractCosineDataTest extends AbstractTimeSeriesTest { @@ -65,10 +70,10 @@ public class AbstractCosineDataTest extends AbstractTimeSeriesTest { String modelId; String entityName; String detectorId; - ModelState modelState; + ModelState modelState; Clock clock; float priority; - EntityColdStarter entityColdStarter; + ADEntityColdStart entityColdStarter; NodeStateManager stateManager; SearchFeatureDao searchFeatureDao; Imputer imputer; @@ -78,13 +83,13 @@ public class 
AbstractCosineDataTest extends AbstractTimeSeriesTest { ThreadPool threadPool; AtomicBoolean released; Runnable releaseSemaphore; - ActionListener listener; + ActionListener>> listener; CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; + ADCheckpointWriteWorker checkpointWriteQueue; Entity entity; AnomalyDetector detector; long rcfSeed; - ModelManager modelManager; + ADModelManager modelManager; ClientUtil clientUtil; ClusterService clusterService; ClusterSettings clusterSettings; @@ -153,16 +158,14 @@ public void setUp() throws Exception { imputer = new LinearUniformImputer(true); searchFeatureDao = mock(SearchFeatureDao.class); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); featureManager = new FeatureManager( searchFeatureDao, imputer, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -172,10 +175,10 @@ public void setUp() throws Exception { TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); rcfSeed = 2051L; - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, @@ -185,15 +188,15 @@ public void setUp() throws Exception { numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, + // settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, rcfSeed, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 ); detectorId = "123"; @@ -211,8 +214,8 @@ public void setUp() throws Exception { }; listener = ActionListener.wrap(releaseSemaphore); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), TimeSeriesSettings.NUM_TREES, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java index 72358af10..f44c468ec 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java @@ -23,7 +23,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.ad.ml.CheckpointDao.FIELD_MODELV2; +import static org.opensearch.ad.ml.ADCheckpointDao.FIELD_MODELV2; import java.io.BufferedReader; import java.io.File; @@ -40,19 +40,15 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.time.Month; -import java.time.OffsetDateTime; -import java.time.ZoneOffset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Deque; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.NoSuchElementException; import java.util.Optional; -import 
java.util.Queue; import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -101,7 +97,10 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ClientUtil; @@ -127,7 +126,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase { private static final Logger logger = LogManager.getLogger(CheckpointDaoTests.class); - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; // dependencies @Mock(answer = Answers.RETURNS_DEEP_STUBS) @@ -162,6 +161,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase { private ThresholdedRandomCutForestMapper trcfMapper; private V1JsonToV3StateConverter converter; double anomalyRate; + private Instant now; @Before public void setup() { @@ -174,7 +174,8 @@ public void setup() { thresholdingModelClass = HybridThresholdingModel.class; - when(clock.instant()).thenReturn(Instant.now()); + now = Instant.now(); + when(clock.instant()).thenReturn(now); mapper = new RandomCutForestMapper(); mapper.setSaveExecutorContextEnabled(true); @@ -211,10 +212,9 @@ public PooledObject wrap(LinkedBuffer obj) { serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); anomalyRate = 0.005; - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -225,7 +225,8 @@ public PooledObject wrap(LinkedBuffer obj) { maxCheckpointBytes, serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); @@ -463,10 +464,6 @@ public void test_getModelCheckpoint_returnEmptyToListener_whenModelNotFound() { GetRequest getRequest = requestCaptor.getValue(); assertEquals(indexName, getRequest.index()); assertEquals(modelId, getRequest.id()); - // ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Exception.class); - // verify(listener).onFailure(responseCaptor.capture()); - // Exception exception = responseCaptor.getValue(); - // assertTrue(exception instanceof ResourceNotFoundException); ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); assertTrue(!responseCaptor.getValue().isPresent()); @@ -496,14 +493,15 @@ public void test_deleteModelCheckpoint_callListener_whenCompleted() { @SuppressWarnings("unchecked") public void test_restore() throws IOException { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - EntityModel modelToSave = state.getModel(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ThresholdedRandomCutForest modelToSave = state.getModel().get(); GetResponse getResponse = mock(GetResponse.class); when(getResponse.isExists()).thenReturn(true); Map source = new HashMap<>(); - source.put(CheckpointDao.DETECTOR_ID, state.getId()); - source.put(CheckpointDao.FIELD_MODELV2, checkpointDao.toCheckpoint(modelToSave, modelId).get()); + source.put(ADCheckpointDao.DETECTOR_ID, state.getConfigId()); + 
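Two related changes run through this file: the DAO constructor now receives a Clock, and setup() pins clock.instant() to a saved now. Both serve determinism; with a fixed clock, any timestamp the DAO stamps (such as a model's last-used time) can be asserted exactly. A plain-Java illustration:

import java.time.Clock;
import java.time.Instant;
import java.time.ZoneOffset;

public class FixedClockSketch {
    public static void main(String[] args) {
        // A fixed clock always returns the same instant, so fields derived
        // from clock.instant() can be compared with assertEquals.
        Clock fixed = Clock.fixed(Instant.parse("2020-10-11T22:58:23Z"), ZoneOffset.UTC);
        System.out.println(fixed.instant()); // 2020-10-11T22:58:23Z, every time
    }
}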
source.put(ADCheckpointDao.FIELD_MODELV2, checkpointDao.toCheckpoint(modelToSave, modelId).get()); source.put(CommonName.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); when(getResponse.getSource()).thenReturn(source); @@ -514,30 +512,21 @@ public void test_restore() throws IOException { return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); - ActionListener>> listener = mock(ActionListener.class); - checkpointDao.deserializeModelCheckpoint(modelId, listener); + ModelState modelState = checkpointDao + .processHCGetResponse(getResponse, modelId, ADCheckpointDao.DETECTOR_ID); - ArgumentCaptor>> responseCaptor = ArgumentCaptor.forClass(Optional.class); - verify(listener).onResponse(responseCaptor.capture()); - Optional> response = responseCaptor.getValue(); - assertTrue(response.isPresent()); - Entry entry = response.get(); - OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); - assertEquals(2020, utcTime.getYear()); - assertEquals(Month.OCTOBER, utcTime.getMonth()); - assertEquals(11, utcTime.getDayOfMonth()); - assertEquals(22, utcTime.getHour()); - assertEquals(58, utcTime.getMinute()); - assertEquals(23, utcTime.getSecond()); - - EntityModel model = entry.getKey(); - Queue queue = model.getSamples(); - Queue samplesToSave = modelToSave.getSamples(); + Instant utcTime = modelState.getLastCheckpointTime(); + // Oct 11, 2020 22:58:23 UTC + assertEquals(Instant.ofEpochSecond(1602457103), utcTime); + + ThresholdedRandomCutForest model = modelState.getModel().get(); + assertEquals(modelToSave.getForest().getTotalUpdates(), model.getForest().getTotalUpdates()); + + Deque queue = modelState.getSamples(); + Deque samplesToSave = state.getSamples(); assertEquals(samplesToSave.size(), queue.size()); - assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); - logger.info(modelToSave.getTrcf()); - logger.info(model.getTrcf()); - assertEquals(modelToSave.getTrcf().get().getForest().getTotalUpdates(), model.getTrcf().get().getForest().getTotalUpdates()); + assertEquals(samplesToSave.peek(), queue.peek()); + } public void test_batch_write_no_index() { @@ -644,7 +633,7 @@ public void test_batch_write_no_init() throws InterruptedException { final CountDownLatch processingLatch = new CountDownLatch(1); checkpointDao - .batchWrite(new BulkRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); + .batchWrite(new BulkRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); // we don't expect the waiting time elapsed before the count reached zero assertTrue(processingLatch.await(100, TimeUnit.SECONDS)); @@ -679,10 +668,9 @@ public void test_batch_read() throws InterruptedException { } public void test_too_large_checkpoint() throws IOException { - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -693,16 +681,19 @@ public void test_too_large_checkpoint() throws IOException { 1, // make the max checkpoint size 1 byte only serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); assertTrue(checkpointDao.toIndexSource(state).isEmpty()); } public void test_to_index_source() throws IOException 
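The rewritten test_restore replaces field-by-field OffsetDateTime assertions with a single epoch comparison. The constant is easy to verify in isolation: 1602457103 seconds after the epoch is 2020-10-11T22:58:23Z, i.e. the stored "2020-10-11T22:58:23.610392Z" with its sub-second part dropped:

import java.time.Instant;

public class EpochCheck {
    public static void main(String[] args) {
        // Prints 2020-10-11T22:58:23Z, matching the timestamp in the source map.
        System.out.println(Instant.ofEpochSecond(1602457103L));
    }
}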
{ - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); Map source = checkpointDao.toIndexSource(state); assertTrue(!source.isEmpty()); @@ -716,10 +707,9 @@ public void test_to_index_source() throws IOException { public void testBorrowFromPoolFailure() throws Exception { GenericObjectPool mockSerializeRCFBufferPool = mock(GenericObjectPool.class); when(mockSerializeRCFBufferPool.borrowObject()).thenThrow(NoSuchElementException.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -730,21 +720,22 @@ public void testBorrowFromPoolFailure() throws Exception { 1, // make the max checkpoint size 1 byte only mockSerializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - assertTrue(!checkpointDao.toCheckpoint(state.getModel(), modelId).get().isEmpty()); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + assertTrue(!checkpointDao.toCheckpoint(state.getModel().get(), modelId).get().isEmpty()); } public void testMapperFailure() throws IOException { ThresholdedRandomCutForestMapper mockMapper = mock(ThresholdedRandomCutForestMapper.class); when(mockMapper.toState(any())).thenThrow(RuntimeException.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -755,44 +746,47 @@ public void testMapperFailure() throws IOException { 1, // make the max checkpoint size 1 byte only serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); // make sure sample size is not 0 otherwise sample size won't be written to checkpoint - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertEquals(null, JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); + assertEquals(null, JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); assertTrue(null != JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); // assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); } public void testEmptySample() throws IOException { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); assertEquals(null, JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); // assertTrue(null != 
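test_too_large_checkpoint constructs the DAO with a 1-byte limit and expects toIndexSource to return an empty map. The guard it exercises reduces to a size check; the sketch below is hypothetical (names and method shape are illustrative, not the plugin's code):

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;

public class SizeGuardSketch {
    // Returns an empty source when the serialized model would exceed the
    // configured maximum, so no oversized document reaches the index.
    static Map<String, Object> toIndexSource(String serialized, int maxCheckpointBytes) {
        if (serialized == null || serialized.getBytes(StandardCharsets.UTF_8).length > maxCheckpointBytes) {
            return Collections.emptyMap();
        }
        return Collections.singletonMap("modelv2", serialized);
    }
}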
JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointErcfCheckoutFail() throws Exception { when(serializeRCFBufferPool.borrowObject()).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } @SuppressWarnings("unchecked") private void setUpMockTrcf() { trcfMapper = mock(ThresholdedRandomCutForestMapper.class); trcfSchema = mock(Schema.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -803,7 +797,8 @@ private void setUpMockTrcf() { maxCheckpointBytes, serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); } @@ -811,10 +806,11 @@ public void testToCheckpointTrcfCheckoutBufferFail() throws Exception { setUpMockTrcf(); when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointTrcfFailNewBuffer() throws Exception { @@ -822,10 +818,11 @@ public void testToCheckpointTrcfFailNewBuffer() throws Exception { doReturn(null).when(serializeRCFBufferPool).borrowObject(); when(trcfMapper.toState(any())).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception { @@ -833,42 +830,52 @@ public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); doThrow(RuntimeException.class).when(serializeRCFBufferPool).invalidateObject(any()); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); - String json = 
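testBorrowFromPoolFailure and the trcf-buffer tests around it all pin the same degradation rule: a broken serialization-buffer pool must not fail checkpoint writing. A sketch of the borrow-or-allocate pattern, assuming only the commons-pool2 and protostuff types this file already uses:

import org.apache.commons.pool2.impl.GenericObjectPool;

import io.protostuff.LinkedBuffer;

public class BufferSketch {
    // Borrow a reusable buffer if the pool cooperates; otherwise fall back to
    // a fresh allocation so serialization still proceeds.
    static LinkedBuffer checkoutOrAllocate(GenericObjectPool<LinkedBuffer> pool, int bufferBytes) {
        try {
            LinkedBuffer borrowed = pool.borrowObject();
            return borrowed != null ? borrowed : LinkedBuffer.allocate(bufferBytes);
        } catch (Exception e) {
            return LinkedBuffer.allocate(bufferBytes);
        }
    }
}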
checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testFromEntityModelCheckpointWithTrcf() throws Exception { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - String model = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + String model = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - Map entity = new HashMap<>(); - entity.put(FIELD_MODELV2, model); - entity.put(CommonName.TIMESTAMP, Instant.now().toString()); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + Map source = new HashMap<>(); + source.put(ADCheckpointDao.DETECTOR_ID, state.getConfigId()); + source.put(FIELD_MODELV2, model); + source.put(CommonName.TIMESTAMP, Instant.now().toString()); - assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); - assertTrue(entityModel.getTrcf().isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(source); + + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + + assertTrue(result != null); + assertTrue(result.getModel().isPresent()); } public void testFromEntityModelCheckpointTrcfMapperFail() throws Exception { setUpMockTrcf(); when(trcfMapper.toModel(any())).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - String model = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + String model = checkpointDao.toCheckpoint(state.getModel().get(), modelId).get(); - Map entity = new HashMap<>(); - entity.put(FIELD_MODELV2, model); - entity.put(CommonName.TIMESTAMP, Instant.now().toString()); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + Map source = new HashMap<>(); + source.put(FIELD_MODELV2, model); + source.put(CommonName.TIMESTAMP, Instant.now().toString()); - assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); - assertFalse(entityModel.getTrcf().isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(source); + + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + + assertTrue(result != null); + assertTrue(result.getModel().isEmpty()); } private Pair, Instant> setUp1_0Model(String checkpointFileName) throws FileNotFoundException, @@ -896,20 +903,22 @@ public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOE Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); Instant now = modelPair.getRight(); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); - Entry pair = result.get(); - 
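From testFromEntityModelCheckpointWithTrcf onward, every restore test switches from feeding a raw source map into fromEntityModelCheckpoint to wrapping the map in a stubbed GetResponse for processHCGetResponse. The shared stubbing reduces to a helper like this (the helper itself is hypothetical; isExists and getSource are the real GetResponse accessors):

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Map;

import org.opensearch.action.get.GetResponse;

public class GetResponseStub {
    // Wrap a checkpoint source map in a GetResponse that reports existence.
    public static GetResponse existing(Map<String, Object> source) {
        GetResponse response = mock(GetResponse.class);
        when(response.isExists()).thenReturn(true);
        when(response.getSource()).thenReturn(source);
        return response;
    }
}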
assertEquals(now, pair.getValue()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); - EntityModel entityModel = pair.getKey(); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); + assertEquals(now, result.getLastCheckpointTime()); + + Deque samples = result.getSamples(); - Queue samples = entityModel.getSamples(); assertEquals(6, samples.size()); - double[] firstSample = samples.peek(); + double[] firstSample = samples.peek().getValueList(); assertEquals(1, firstSample.length); assertEquals(0.6832234717598454, firstSample[0], 1e-10); - ThresholdedRandomCutForest trcf = entityModel.getTrcf().get(); + ThresholdedRandomCutForest trcf = result.getModel().get(); RandomCutForest forest = trcf.getForest(); assertEquals(1, forest.getDimensions()); assertEquals(10, forest.getNumberOfTrees()); @@ -926,10 +935,9 @@ public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOE public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, - indexName, gson, mapper, converter, @@ -940,43 +948,60 @@ public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundExce 100_000, // checkpoint_2.json is of 224603 bytes. serializeRCFBufferPool, TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, - anomalyRate + anomalyRate, + clock ); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); // checkpoint is only configured to take in 1 MB checkpoint at most. But the checkpoint here is of 1408047 bytes. 
- assertTrue(!result.isPresent()); + assertTrue(result.getModel().isEmpty()); } // test no model is present in checkpoint public void testFromEntityModelCheckpointEmptyModel() throws FileNotFoundException, IOException, URISyntaxException { Map entity = new HashMap<>(); + entity.put(ADCheckpointDao.DETECTOR_ID, ADCheckpointDao.DETECTOR_ID); entity.put(CommonName.TIMESTAMP, Instant.now().toString()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(entity); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); - assertTrue(!result.isPresent()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result.getModel().isEmpty()); } public void testFromEntityModelCheckpointEmptySamples() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_1.json"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); - Queue samples = result.get().getKey().getSamples(); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); + Deque samples = result.getSamples(); assertEquals(0, samples.size()); } public void testFromEntityModelCheckpointNoRCF() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_3.json"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); - assertTrue(!result.get().getKey().getTrcf().isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); + assertTrue(result.getModel().isEmpty()); } public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_4.json"); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); - assertTrue(result.isPresent()); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(modelPair.getLeft()); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); + assertTrue(result != null); - ThresholdedRandomCutForest trcf = result.get().getKey().getTrcf().get(); + ThresholdedRandomCutForest trcf = result.getModel().get(); RandomCutForest forest = trcf.getForest(); assertEquals(1, forest.getDimensions()); assertEquals(10, forest.getNumberOfTrees()); @@ -984,19 +1009,20 @@ public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundExcept } public void testFromEntityModelCheckpointWithEntity() throws Exception { - ModelState state = MLUtil + ModelState state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).entityAttributes(true).build()); Map content = checkpointDao.toIndexSource(state); // Opensearch will convert from java.time.ZonedDateTime to String. 
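The checkpoint-shape tests above (too large, empty model, empty samples, no RCF, no threshold) all assert against the same new contract: processHCGetResponse returns a ModelState whose model is an Optional, empty when the checkpoint lacked a usable forest, instead of an absent Optional<Entry>. A hypothetical stand-in shows the shape:

import java.util.Optional;

public class StateSketch<T> {
    private final T model; // null when the checkpoint held no usable model

    public StateSketch(T model) {
        this.model = model;
    }

    // Callers branch on Optional rather than null-checking an EntityModel.
    public Optional<T> getModel() {
        return Optional.ofNullable(model);
    }
}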
Here I am converting to simulate that content.put(CommonName.TIMESTAMP, "2021-09-23T05:00:37.93195Z"); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSource()).thenReturn(content); - Optional> result = checkpointDao.fromEntityModelCheckpoint(content, this.modelId); + ModelState result = checkpointDao + .processHCGetResponse(getResponse, this.modelId, ADCheckpointDao.DETECTOR_ID); - assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); - assertTrue(entityModel.getEntity().isPresent()); - assertEquals(state.getModel().getEntity().get(), entityModel.getEntity().get()); + assertTrue(result != null); + assertTrue(result.getEntity().isPresent()); + assertEquals(state.getEntity().get(), result.getEntity().get()); } private double[] getPoint(int dimensions, Random random) { @@ -1234,4 +1260,28 @@ public static String unescapeJavaString(String st) { } return sb.toString(); } + + public void testProcessEmptyCheckpoint() throws IOException { + String modelId = "abc"; + ModelState modelState = checkpointDao + .processHCGetResponse(TestHelpers.createBrokenGetResponse(modelId, "blah"), modelId, "123"); + assertEquals(Instant.MIN, modelState.getLastCheckpointTime()); + } + + public void testNonEmptyCheckpoint() throws IOException { + String modelId = "abc"; + String detectorId = "123"; + Pair> model = MLUtil.createNonEmptyModel(modelId); + ModelState inputModelState = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + Instant checkpointTime = Instant.ofEpochMilli(1000); + inputModelState.setLastCheckpointTime(checkpointTime); + + Map source = checkpointDao.toIndexSource(inputModelState); + ModelState modelState = checkpointDao + .processHCGetResponse(TestHelpers.createGetResponse(source, modelId, "blah"), modelId, "123"); + assertEquals(checkpointTime, modelState.getLastCheckpointTime()); + assertEquals(inputModelState.getSamples().size(), modelState.getSamples().size()); + assertEquals(now, modelState.getLastUsedTime()); + } } diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java index dcda1ff92..c9578ff50 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java @@ -17,6 +17,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.time.Clock; import java.util.Arrays; import java.util.Collections; import java.util.Locale; @@ -35,6 +36,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.ScrollableHitSource; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.ml.CheckpointDao; import org.opensearch.timeseries.util.ClientUtil; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; @@ -60,7 +62,7 @@ private enum DeleteExecutionMode { PARTIAL_FAILURE } - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; private Client client; private ClientUtil clientUtil; private Gson gson; @@ -77,12 +79,14 @@ private enum DeleteExecutionMode { double anomalyRate; + private Clock clock; + @SuppressWarnings("unchecked") @Override @Before public void setUp() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(CheckpointDao.class); + super.setUpLog4jForJUnit(ADCheckpointDao.class); client = mock(Client.class); clientUtil = mock(ClientUtil.class); @@ -97,10 
+101,10 @@ public void setUp() throws Exception { objectPool = mock(GenericObjectPool.class); int deserializeRCFBufferSize = 512; anomalyRate = 0.005; - checkpointDao = new CheckpointDao( + clock = mock(Clock.class); + checkpointDao = new ADCheckpointDao( client, clientUtil, - ADCommonName.CHECKPOINT_INDEX_NAME, gson, mapper, converter, @@ -111,7 +115,8 @@ public void setUp() throws Exception { maxCheckpointBytes, objectPool, deserializeRCFBufferSize, - anomalyRate + anomalyRate, + clock ); } @@ -157,7 +162,7 @@ public void delete_by_detector_id_template(DeleteExecutionMode mode) { return null; }).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); - checkpointDao.deleteModelCheckpointByDetectorId(detectorId); + checkpointDao.deleteModelCheckpointByConfigId(detectorId); } public void testDeleteSingleNormal() throws Exception { @@ -172,7 +177,7 @@ public void testDeleteSingleIndexNotFound() throws Exception { public void testDeleteSingleResultFailure() throws Exception { delete_by_detector_id_template(DeleteExecutionMode.FAILURE); - assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_LOG_MSG)); + assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_CHECKPOINT_MSG)); } public void testDeleteSingleResultPartialFailure() throws Exception { diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java index 188146f69..00d11e71a 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java @@ -29,10 +29,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.Deque; import java.util.List; +import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.Queue; import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -43,8 +44,6 @@ import org.junit.BeforeClass; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; @@ -59,7 +58,13 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.amazon.randomcutforest.config.Precision; @@ -105,21 +110,48 @@ public void tearDown() throws Exception { // train using samples directly public void testTrainUsingSamples() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(numMinSamples); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); - 
assertTrue(model.getTrcf().isPresent()); - ThresholdedRandomCutForest ercf = model.getTrcf().get(); + Deque samples = MLUtil.createQueueSamples(numMinSamples); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); + assertTrue(modelState.getModel().isPresent()); + ThresholdedRandomCutForest ercf = modelState.getModel().get(); assertEquals(numMinSamples, ercf.getForest().getTotalUpdates()); checkSemaphoreRelease(); } public void testColdStart() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + ModelState model = new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -142,25 +174,46 @@ public void testColdStart() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + CountDownLatch latch = new CountDownLatch(1); + ActionListener>> listener = ActionListener.wrap(response -> { latch.countDown(); }, exception -> { + assertFalse("should not reach here", true); + }); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); - ThresholdedRandomCutForest ercf = model.getTrcf().get(); + assertTrue(model.getModel().isPresent()); + ThresholdedRandomCutForest ercf = model.getModel().get(); // 1 round: stride * (samples - 1) + 1 = 60 * 2 + 1 = 121 // plus 1 existing sample assertEquals(121, ercf.getForest().getTotalUpdates()); - assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + assertTrue("size: " + modelState.getSamples().size(), model.getSamples().isEmpty()); checkSemaphoreRelease(); released.set(false); // too frequent cold start of the same detector will fail samples = MLUtil.createQueueSamples(1); - model = new EntityModel(entity, samples, null); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); - assertFalse(model.getTrcf().isPresent()); + assertFalse(model.getModel().isPresent()); // the samples is not touched since cold start does not happen assertEquals("size: " + 
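The totalUpdates expectations in these cold-start tests all come from the interpolation rule the inline comments quote: one round contributes stride * (samples - 1) + 1 points. Worked out with the strides and sample counts the comments name:

public class ColdStartMath {
    static int interpolatedPoints(int stride, int samples) {
        return stride * (samples - 1) + 1;
    }

    public static void main(String[] args) {
        System.out.println(interpolatedPoints(60, 3)); // 121, as in testColdStart
        System.out.println(interpolatedPoints(60, 5)); // 241, testTwoSegmentsWithSingleSample
        System.out.println(interpolatedPoints(60, 6)); // 301, testTwoSegments
    }
}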
model.getSamples().size(), 1, model.getSamples().size()); checkSemaphoreRelease(); @@ -181,9 +234,18 @@ public void testColdStart() throws InterruptedException, IOException { // min max: miss one public void testMissMin() throws IOException, InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -191,11 +253,20 @@ public void testMissMin() throws IOException, InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - assertTrue(!model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isEmpty()); checkSemaphoreRelease(); } @@ -204,7 +275,7 @@ public void testMissMin() throws IOException, InterruptedException { * @param modelState an initialized model state * @param coldStartData cold start data that initialized the modelState */ - private void diffTesting(ModelState modelState, List coldStartData) { + private void diffTesting(ModelState modelState, List coldStartData) { int inputDimension = detector.getEnabledFeatureIds().size(); ThresholdedRandomCutForest refTRcf = ThresholdedRandomCutForest @@ -243,8 +314,8 @@ private void diffTesting(ModelState modelState, List cold for (int i = 0; i < 100; i++) { double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); AnomalyDescriptor descriptor = refTRcf.process(point, 0); - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize()); + Sample sample = new Sample(point, Instant.now(), Instant.now()); + ThresholdingResult result = modelManager.getResult(sample, modelState, modelId, Optional.of(entity), detector, "123"); assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10); assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10); } @@ -268,10 +339,19 @@ private List convertToFeatures(double[][] interval, int numValsToKeep) // two segments of samples, one segment has 3 samples, while another one has only 1 public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - double[] savedSample = samples.peek(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + double[] savedSample = samples.peek().getValueList(); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + 
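diffTesting above is differential testing: the cold-started model must score every point exactly like a reference forest built with the same seed and parameters. The core of the idea, as a self-contained sketch assuming the RCF builder API already used in this diff:

import java.util.Random;

import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

public class DiffTestSketch {
    public static void main(String[] args) {
        // Two forests from the same seed, fed the same stream, must agree.
        ThresholdedRandomCutForest a = ThresholdedRandomCutForest.builder().dimensions(1).shingleSize(1).randomSeed(2051L).build();
        ThresholdedRandomCutForest b = ThresholdedRandomCutForest.builder().dimensions(1).shingleSize(1).randomSeed(2051L).build();
        Random r = new Random(0);
        for (int i = 0; i < 100; i++) {
            double[] point = new double[] { r.nextInt(50) };
            AnomalyDescriptor da = a.process(point, i);
            AnomalyDescriptor db = b.process(point, i);
            if (Math.abs(da.getRCFScore() - db.getRCFScore()) > 1e-10) {
                throw new AssertionError("forests diverged at step " + i);
            }
        }
    }
}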
Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -295,13 +375,22 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); // 1 round: stride * (samples - 1) + 1 = 60 * 4 + 1 = 241 // if 241 < shingle size + numMinSamples, then another round is performed - assertEquals(241, modelState.getModel().getTrcf().get().getForest().getTotalUpdates()); + assertEquals(241, modelState.getModel().get().getForest().getTotalUpdates()); checkSemaphoreRelease(); List expectedColdStartData = new ArrayList<>(); @@ -315,17 +404,25 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc expectedColdStartData.addAll(convertToFeatures(interval2, 60)); double[][] interval3 = imputer.impute(new double[][] { new double[] { sample3[0], sample5[0] } }, 121); expectedColdStartData.addAll(convertToFeatures(interval3, 121)); - assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + assertTrue("size: " + modelState.getSamples().size(), modelState.getSamples().isEmpty()); assertEquals(241, expectedColdStartData.size()); diffTesting(modelState, expectedColdStartData); } // two segments of samples, one segment has 3 samples, while another one 2 samples public void testTwoSegments() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); - double[] savedSample = samples.peek(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -351,11 +448,20 @@ public void testTwoSegments() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); - ThresholdedRandomCutForest ercf = model.getTrcf().get(); + assertTrue(modelState.getModel().isPresent()); + ThresholdedRandomCutForest ercf = modelState.getModel().get(); // 1 rounds: stride * (samples - 1) + 1 = 60 * 5 + 1 = 301 assertEquals(301, ercf.getForest().getTotalUpdates()); checkSemaphoreRelease(); @@ 
-374,14 +480,23 @@ public void testTwoSegments() throws InterruptedException, IOException { double[][] interval4 = imputer.impute(new double[][] { new double[] { sample5[0], sample6[0] } }, 61); expectedColdStartData.addAll(convertToFeatures(interval4, 61)); assertEquals(301, expectedColdStartData.size()); - assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + assertTrue("size: " + modelState.getSamples().size(), modelState.getSamples().isEmpty()); diffTesting(modelState, expectedColdStartData); } public void testThrottledColdStart() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -389,9 +504,18 @@ public void testThrottledColdStart() throws InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); - entityColdStarter.trainModel(entity, "456", modelState, listener); + entityColdStarter.trainModel(featureRequest, "456", modelState, listener); // only the first one makes the call verify(searchFeatureDao, times(1)).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); @@ -399,9 +523,18 @@ public void testThrottledColdStart() throws InterruptedException { } public void testColdStartException() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -409,7 +542,16 @@ public void testColdStartException() throws InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); assertTrue(stateManager.fetchExceptionAndClear(detectorId).isPresent()); checkSemaphoreRelease(); @@ -417,9 +559,18 @@ public void testColdStartException() throws InterruptedException { @SuppressWarnings("unchecked") public void testNotEnoughSamples() throws InterruptedException, IOException { - Queue samples = MLUtil.createQueueSamples(1); 
- EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); detector = TestHelpers.AnomalyDetectorBuilder .newInstance() @@ -449,17 +600,26 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(!model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isEmpty()); // 1st round we add 57 and 1. // 2nd round we add 57 and 1. - Queue currentSamples = model.getSamples(); + Deque currentSamples = modelState.getSamples(); assertEquals("real sample size is " + currentSamples.size(), 4, currentSamples.size()); int j = 0; while (!currentSamples.isEmpty()) { - double[] element = currentSamples.poll(); + double[] element = currentSamples.poll().getValueList(); assertEquals(1, element.length); if (j == 0 || j == 2) { assertEquals(57, element[0], 1e-10); @@ -472,9 +632,18 @@ public void testNotEnoughSamples() throws InterruptedException, IOException { @SuppressWarnings("unchecked") public void testEmptyDataRange() throws InterruptedException { - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); // the min-max range 894056973000L~894057860000L is too small and thus no data range can be found when(clock.millis()).thenReturn(894057860000L); @@ -493,12 +662,21 @@ public void testEmptyDataRange() throws InterruptedException { return null; }).when(searchFeatureDao).getMinDataTime(any(), any(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + samples.peek().getValueList(), + samples.peek().getDataStartTime().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(!model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isEmpty()); // the min-max range is too small and thus no data range can be found - assertEquals("real sample size is " + model.getSamples().size(), 1, model.getSamples().size()); + assertEquals("real sample size is " + modelState.getSamples().size(), 1, modelState.getSamples().size()); } public void testTrainModelFromExistingSamplesEnoughSamples() { @@ -524,12 +702,21 @@ public void 
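testNotEnoughSamples asserts that a failed cold start leaves the queued samples untouched: two rounds, each contributing a 57 and a 1. Its draining loop reduces to this plain-Java check:

import java.util.ArrayDeque;
import java.util.Deque;

public class RetainedSamplesSketch {
    public static void main(String[] args) {
        // Four retained one-dimensional samples, alternating 57 and 1.
        Deque<double[]> retained = new ArrayDeque<>();
        for (int round = 0; round < 2; round++) {
            retained.add(new double[] { 57 });
            retained.add(new double[] { 1 });
        }
        int j = 0;
        while (!retained.isEmpty()) {
            double[] element = retained.poll();
            if (element[0] != ((j % 2 == 0) ? 57 : 1)) {
                throw new AssertionError("unexpected sample at " + j);
            }
            j++;
        }
    }
}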
testTrainModelFromExistingSamplesEnoughSamples() { .transformMethod(TransformMethod.NORMALIZE) .alertOnce(true) .autoAdjust(true); - Tuple, ThresholdedRandomCutForest> models = MLUtil.prepareModel(inputDimension, rcfConfig); - Queue samples = models.v1(); + Tuple, ThresholdedRandomCutForest> models = MLUtil.prepareModel(inputDimension, rcfConfig); + Deque samples = models.v1(); ThresholdedRandomCutForest rcf = models.v2(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); Random r = new Random(); @@ -537,19 +724,28 @@ public void testTrainModelFromExistingSamplesEnoughSamples() { for (int i = 0; i < 100; i++) { double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); AnomalyDescriptor descriptor = rcf.process(point, 0); - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize()); + Sample sample = new Sample(point, Instant.now(), Instant.now()); + ThresholdingResult result = modelManager.getResult(sample, modelState, modelId, Optional.of(entity), detector, "123"); assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10); assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10); } } public void testTrainModelFromExistingSamplesNotEnoughSamples() { - Queue samples = new ArrayDeque<>(); - EntityModel model = new EntityModel(entity, samples, null); - modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); - entityColdStarter.trainModelFromExistingSamples(modelState, detector.getShingleSize()); - assertTrue(!modelState.getModel().getTrcf().isPresent()); + Deque samples = new ArrayDeque<>(); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); + entityColdStarter.trainModelFromExistingSamples(modelState, Optional.of(entity), detector, "123"); + assertTrue(modelState.getModel().isEmpty()); } @SuppressWarnings("unchecked") @@ -630,8 +826,17 @@ public int compare(Entry p1, Entry p2) { return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); - modelState = new ModelState<>(model, modelId, detector.getId(), ModelType.ENTITY.getName(), clock, priority); + modelState = new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + new ArrayDeque<>() + ); released = new AtomicBoolean(); @@ -641,10 +846,19 @@ public int compare(Entry p1, Entry p2) { inProgressLatch.countDown(); }); - entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + Instant.now().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detector.getId(), modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); int tp = 0; int fp = 0; @@ -652,8 +866,8 @@ public int 
compare(Entry p1, Entry p2) { long[] changeTimestamps = dataWithKeys.changeTimeStampsMs; for (int j = trainTestSplit; j < data.length; j++) { - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + Sample sample = new Sample(data[j], Instant.now(), Instant.now()); + ThresholdingResult result = modelManager.getResult(sample, modelState, modelId, Optional.of(entity), detector, "123"); if (result.getGrade() > 0) { if (changeTimestamps[j] == 0) { fp++; @@ -701,7 +915,7 @@ public void testAccuracyThirteenMinuteInterval() throws Exception { public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception { ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); // for one minute interval, we need to disable interpolation to achieve good results - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, @@ -711,19 +925,18 @@ public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception { numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, rcfSeed, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 60 ); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), TimeSeriesSettings.NUM_TREES, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, @@ -743,20 +956,29 @@ public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception { accuracyTemplate(1, 0.6f, 0.6f); } - private ModelState createStateForCacheRelease() { + private ModelState createStateForCacheRelease() { inProgressLatch = new CountDownLatch(1); releaseSemaphore = () -> { released.set(true); inProgressLatch.countDown(); }; listener = ActionListener.wrap(releaseSemaphore); - Queue samples = MLUtil.createQueueSamples(1); - EntityModel model = new EntityModel(entity, samples, null); - return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + Deque samples = MLUtil.createQueueSamples(1); + return new ModelState( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + samples + ); } public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedException { - ModelState modelState = createStateForCacheRelease(); + ModelState modelState = createStateForCacheRelease(); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); listener.onResponse(Optional.of(1602269260000L)); @@ -778,15 +1000,24 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + Instant.now().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - 
assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); modelState = createStateForCacheRelease(); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); // model is not trained as the door keeper remembers it and won't retry training - assertTrue(!modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isEmpty()); // make sure when the next maintenance coming, current door keeper gets reset // note our detector interval is 1 minute and the door keeper will expire in 60 intervals, which are 60 minutes @@ -794,14 +1025,23 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx entityColdStarter.maintenance(); modelState = createStateForCacheRelease(); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + Instant.now().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); // model is trained as the door keeper gets reset - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); } public void testCacheReleaseAfterClear() throws IOException, InterruptedException { - ModelState modelState = createStateForCacheRelease(); + ModelState modelState = createStateForCacheRelease(); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); listener.onResponse(Optional.of(1602269260000L)); @@ -823,16 +1063,25 @@ public void testCacheReleaseAfterClear() throws IOException, InterruptedExceptio return null; }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + FeatureRequest featureRequest = new FeatureRequest( + Instant.now().toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + new double[] { 1.3 }, + Instant.now().toEpochMilli(), + entity, + "123" + ); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); entityColdStarter.clear(detectorId); modelState = createStateForCacheRelease(); - entityColdStarter.trainModel(entity, detectorId, modelState, listener); + entityColdStarter.trainModel(featureRequest, detectorId, modelState, listener); checkSemaphoreRelease(); // model is trained as the door keeper is regenerated after clearance - assertTrue(modelState.getModel().getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); } } diff --git a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java index 1f4afe829..365f609a9 100644 --- a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java +++ b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java @@ -11,11 +11,14 @@ package org.opensearch.ad.ml; +import java.time.Instant; import java.util.ArrayDeque; import org.junit.Before; import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import 
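Both cache-release tests hinge on a "door keeper" that suppresses retraining of a model it has already seen until maintenance() or clear() resets it. A hypothetical sketch of that behavior (the real guard may differ, e.g. a Bloom filter with expiry rather than a set):

import java.util.HashSet;
import java.util.Set;

public class DoorKeeperSketch {
    private Set<String> attempted = new HashSet<>();

    // First call per model id returns true (train); repeats are suppressed,
    // which is why the second trainModel call above leaves the model empty.
    boolean shouldTrain(String modelId) {
        return attempted.add(modelId);
    }

    // Maintenance (or an explicit clear) forgets past attempts, so the
    // follow-up trainModel calls succeed again.
    void reset() {
        attempted = new HashSet<>();
    }
}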
com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; @@ -29,45 +32,45 @@ public void setup() { } public void testNullInternalSampleQueue() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(new double[] { 0.8 }); + ModelState model = new ModelState<>(null, null, null, null, null, 0, null, null, null); + model.addSample(new Sample(new double[] { 0.8 }, Instant.now(), Instant.now())); assertEquals(1, model.getSamples().size()); } public void testNullInputSample() { - EntityModel model = new EntityModel(null, null, null); + ModelState model = new ModelState<>(null, null, null, null, null, 0, null, null, null); model.addSample(null); assertEquals(0, model.getSamples().size()); } public void testEmptyInputSample() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(new double[] {}); + ModelState model = new ModelState<>(null, null, null, null, null, 0, null, null, null); + model.addSample(new Sample(new double[] {}, Instant.now(), Instant.now())); assertEquals(0, model.getSamples().size()); } @Test public void trcf_constructor() { - EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); - assertEquals(trcf, em.getTrcf().get()); + ModelState em = new ModelState<>(trcf, null, null, null, null, 0, null, null, new ArrayDeque<>()); + assertEquals(trcf, em.getModel().get()); } @Test public void clear() { - EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); + ModelState em = new ModelState<>(trcf, null, null, null, null, 0, null, null, new ArrayDeque<>()); em.clear(); assertTrue(em.getSamples().isEmpty()); - assertFalse(em.getTrcf().isPresent()); + assertFalse(em.getModel().isPresent()); } @Test public void setTrcf() { - EntityModel em = new EntityModel(null, null, null); - assertFalse(em.getTrcf().isPresent()); + ModelState em = new ModelState<>(null, null, null, null, null, 0, null, null, null); + assertFalse(em.getModel().isPresent()); - em.setTrcf(this.trcf); - assertTrue(em.getTrcf().isPresent()); + em.setModel(this.trcf); + assertTrue(em.getModel().isPresent()); } } diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java index bf2732777..51cf7f16d 100644 --- a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -20,6 +20,7 @@ import static org.mockito.Mockito.when; import java.time.Clock; +import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayDeque; import java.util.ArrayList; @@ -34,8 +35,6 @@ import org.apache.lucene.tests.util.TimeUnits; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -46,11 +45,18 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import 
org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; import com.google.common.collect.ImmutableList; @@ -118,10 +124,8 @@ private void averageAccuracyTemplate( searchFeatureDao, imputer, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -131,7 +135,7 @@ private void averageAccuracyTemplate( TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, @@ -141,19 +145,18 @@ private void averageAccuracyTemplate( numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, seed, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 ); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), TimeSeriesSettings.NUM_TREES, TimeSeriesSettings.NUM_SAMPLES_PER_TREE, @@ -218,14 +221,16 @@ public int compare(Entry p1, Entry p2) { }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), eq(AnalysisType.AD), any()); entity = Entity.createSingleAttributeEntity("field", entityName + z); - EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); - ModelState modelState = new ModelState<>( - model, + ModelState modelState = new ModelState<>( + null, entity.getModelId(detectorId).get(), detector.getId(), - ModelType.ENTITY.getName(), + ModelManager.ModelType.TRCF.getName(), clock, - priority + priority, + null, + Optional.of(entity), + new ArrayDeque<>() ); released = new AtomicBoolean(); @@ -236,10 +241,25 @@ public int compare(Entry p1, Entry p2) { inProgressLatch.countDown(); }); - entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); + long dataStartTimeMs = System.currentTimeMillis(); + entityColdStarter + .trainModel( + new FeatureRequest( + dataStartTimeMs + 60000, + detector.getId(), + RequestPriority.MEDIUM, + new double[] {}, + dataStartTimeMs, + entity, + null + ), + detector.getId(), + modelState, + listener + ); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); + assertTrue(modelState.getModel().isPresent()); int tp = 0; int fp = 0; @@ -248,7 +268,7 @@ public int compare(Entry p1, Entry p2) { for (int j = trainTestSplit; j < data.length; j++) { ThresholdingResult result = modelManager - .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + .getResult(new Sample(data[j], Instant.now(), Instant.now()), modelState, modelId, Optional.of(entity), detector, null); if (result.getGrade() > 0) { if 
(changeTimestamps[j] == 0) { fp++; diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java index cbb7b09ba..116f7adf2 100644 --- a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java @@ -51,11 +51,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -71,7 +68,11 @@ import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -91,7 +92,7 @@ @SuppressWarnings("unchecked") public class ModelManagerTests { - private ModelManager modelManager; + private ADModelManager modelManager; @Mock private AnomalyDetector anomalyDetector; @@ -103,7 +104,7 @@ public class ModelManagerTests { private JvmService jvmService; @Mock - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; @Mock private Clock clock; @@ -112,16 +113,10 @@ public class ModelManagerTests { private FeatureManager featureManager; @Mock - private EntityColdStarter entityColdStarter; + private ADEntityColdStart entityColdStarter; @Mock - private EntityCache cache; - - @Mock - private ModelState modelState; - - @Mock - private EntityModel entityModel; + private ModelState modelState; @Mock private ThresholdedRandomCutForest trcf; @@ -225,7 +220,7 @@ public void setup() { .build(); modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, @@ -248,8 +243,7 @@ public void setup() { rcfModelId = "detectorId_model_rcf_1"; thresholdModelId = "detectorId_model_threshold"; - when(this.modelState.getModel()).thenReturn(this.entityModel); - when(this.entityModel.getTrcf()).thenReturn(Optional.of(this.trcf)); + when(this.modelState.getModel()).thenReturn(Optional.of(this.trcf)); when(anomalyDetector.getShingleSize()).thenReturn(shingleSize); } @@ -267,7 +261,7 @@ private Object[] getDetectorIdForModelIdData() { @Test @Parameters(method = "getDetectorIdForModelIdData") public void getDetectorIdForModelId_returnExpectedId(String modelId, String expectedDetectorId) { - assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getDetectorIdForModelId(modelId)); + assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getConfigIdForModelId(modelId)); } private Object[] getDetectorIdForModelIdIllegalArgument() { @@ -277,7 +271,7 @@ private Object[] getDetectorIdForModelIdIllegalArgument() { @Test(expected = IllegalArgumentException.class) 
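// ---- reviewer aside (illustration only, not part of the patch; the annotated test continues right below) ----
// The hunks around this point rename SingleStreamModelIdMapper.getDetectorIdForModelId to
// getConfigIdForModelId: single-stream model ids such as "detectorId_model_rcf_1" or
// "detectorId_model_threshold" now resolve to a generic config id rather than a detector id.
// A minimal sketch of that naming convention, inferred from the test fixtures above; this is
// not the mapper's actual implementation:
class ModelIdConventionSketch {
    private static final String SEPARATOR = "_model_";

    static String configIdFor(String modelId) {
        int idx = modelId.lastIndexOf(SEPARATOR);
        if (idx <= 0) {
            // mirrors the IllegalArgumentException the mapper throws for malformed ids
            throw new IllegalArgumentException("invalid model id: " + modelId);
        }
        return modelId.substring(0, idx); // configIdFor("detectorId_model_rcf_1") -> "detectorId"
    }
}
// ---- end aside ----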
@Parameters(method = "getDetectorIdForModelIdIllegalArgument") public void getDetectorIdForModelId_throwIllegalArgument_forInvalidId(String modelId) { - SingleStreamModelIdMapper.getDetectorIdForModelId(modelId); + SingleStreamModelIdMapper.getConfigIdForModelId(modelId); } private Map<String, DiscoveryNode> createDataNodes(int numDataNodes) { @@ -415,7 +409,7 @@ public void getRcfResult_throwToListener_whenHeapLimitExceed() { // use new memoryTracker modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, @@ -868,34 +862,12 @@ public void getPreviewResults_throwIllegalArgument_forInvalidInput() { modelManager.getPreviewResults(new double[0][0], shingleSize); } - @Test - public void processEmptyCheckpoint() { - ModelState<EntityModel> modelState = modelManager.processEntityCheckpoint(Optional.empty(), null, "", "", shingleSize); - assertEquals(Instant.MIN, modelState.getLastCheckpointTime()); - } - - @Test - public void processNonEmptyCheckpoint() { - String modelId = "abc"; - String detectorId = "123"; - EntityModel model = MLUtil.createNonEmptyModel(modelId); - Instant checkpointTime = Instant.ofEpochMilli(1000); - ModelState<EntityModel> modelState = modelManager .processEntityCheckpoint( Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), null, modelId, detectorId, shingleSize ); - assertEquals(checkpointTime, modelState.getLastCheckpointTime()); - assertEquals(model.getSamples().size(), modelState.getModel().getSamples().size()); - assertEquals(now, modelState.getLastUsedTime()); - } - @Test public void getNullState() { - assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity(new double[] {}, null, "", null, shingleSize)); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getResult(new Sample(new double[] {}, Instant.now(), Instant.now()), null, "", null, anomalyDetector, "") + ); } @Test @@ -909,10 +881,8 @@ public void getEmptyStateFullSamples() { searchFeatureDao, interpolator, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, - AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, - AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + TimeSeriesSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + TimeSeriesSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, @@ -922,9 +892,9 @@ public void getEmptyStateFullSamples() { TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - CheckpointWriteWorker checkpointWriteQueue = mock(CheckpointWriteWorker.class); + ADCheckpointWriteWorker checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, @@ -934,18 +904,17 @@ public void getEmptyStateFullSamples() { numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - interpolator, searchFeatureDao, TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - TimeSeriesSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + 1 ); modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, @@ -964,50 +933,60 @@ public void getEmptyStateFullSamples() { ) ); - ModelState<EntityModel> state = MLUtil + ModelState<ThresholdedRandomCutForest> state = MLUtil .randomModelState(new
RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples).build()); - EntityModel model = state.getModel(); - assertTrue(!model.getTrcf().isPresent()); - ThresholdingResult result = modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); + Optional<ThresholdedRandomCutForest> model = state.getModel(); + assertTrue(model.isEmpty()); + ThresholdingResult result = modelManager + .getResult(new Sample(new double[] { -1 }, Instant.now(), Instant.now()), state, "", Optional.empty(), anomalyDetector, ""); // model outputs scores assertTrue(result.getRcfScore() != 0); // added the sample to score since our model is empty - assertEquals(0, model.getSamples().size()); + assertEquals(0, state.getSamples().size()); } @Test public void getAnomalyResultForEntityNoModel() { - ModelState<EntityModel> modelState = new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + ModelState<ThresholdedRandomCutForest> modelState = new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + null + ); ThresholdingResult result = modelManager - .getAnomalyResultForEntity( - new double[] { -1 }, + .getResult( + new Sample(new double[] { -1 }, Instant.now(), Instant.now()), modelState, modelId, - Entity.createSingleAttributeEntity("field", "val"), - shingleSize + Optional.of(Entity.createSingleAttributeEntity("field", "val")), + anomalyDetector, + "" ); // model outputs scores assertEquals(new ThresholdingResult(0, 0, 0), result); // added the sample to score since our model is empty - assertEquals(1, modelState.getModel().getSamples().size()); + assertEquals(1, modelState.getSamples().size()); } @Test public void getEmptyStateNotFullSamples() { - ModelState<EntityModel> state = MLUtil + ModelState<ThresholdedRandomCutForest> state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples - 1).build()); assertEquals( new ThresholdingResult(0, 0, 0), - modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize) + modelManager.getResult(new Sample(new double[] { -1 }, Instant.now(), Instant.now()), state, "", null, anomalyDetector, "") ); - assertEquals(numMinSamples, state.getModel().getSamples().size()); + assertEquals(numMinSamples, state.getSamples().size()); } @Test public void scoreSamples() { - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); - assertEquals(0, state.getModel().getSamples().size()); + ModelState<ThresholdedRandomCutForest> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + modelManager.getResult(new Sample(new double[] { -1 }, Instant.now(), Instant.now()), state, "", null, anomalyDetector, ""); + assertEquals(0, state.getSamples().size()); assertEquals(now, state.getLastUsedTime()); } @@ -1019,7 +998,7 @@ public void getAnomalyResultForEntity_withTrcf() { when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); ThresholdingResult result = modelManager - .getAnomalyResultForEntity(this.point, this.modelState, this.detectorId, null, this.shingleSize); + .getResult(new Sample(this.point, Instant.now(), Instant.now()), this.modelState, this.detectorId, null, anomalyDetector, ""); assertEquals( new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), @@ -1042,9 +1021,11 @@ public void score_with_trcf() { when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); when(this.trcf.process(this.point,
0)).thenReturn(anomalyDescriptor); - when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); + when(this.modelState.getSamples()) + .thenReturn(new ArrayDeque<>(Arrays.asList(new Sample(this.point, Instant.now(), Instant.now())))); - ThresholdingResult result = modelManager.score(this.point, this.detectorId, this.modelState); + ThresholdingResult result = modelManager + .score(new Sample(this.point, Instant.now(), Instant.now()), this.modelId, this.modelState, anomalyDetector); assertEquals( new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), @@ -1075,7 +1056,8 @@ public void score_throw() { when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); doThrow(new IllegalArgumentException()).when(trcf).process(any(), anyLong()); - when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); - modelManager.score(this.point, this.detectorId, this.modelState); + when(this.modelState.getSamples()) + .thenReturn(new ArrayDeque<>(Arrays.asList(new Sample(this.point, Instant.now(), Instant.now())))); + modelManager.score(new Sample(this.point, Instant.now(), Instant.now()), this.modelId, this.modelState, anomalyDetector); } } diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java index 327e3bf51..9a58b9c4f 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java @@ -12,16 +12,16 @@ package org.opensearch.ad.mock.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.JobResponse; -public class MockAnomalyDetectorJobAction extends ActionType<AnomalyDetectorJobResponse> { +public class MockAnomalyDetectorJobAction extends ActionType<JobResponse> { // External Action which is used for public-facing REST APIs.
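// ---- reviewer aside (illustration only, not part of the patch) ----
// The ModelManagerTests hunks above stub trcf.process(point, 0) and verify that the returned
// AnomalyDescriptor is folded into a ThresholdingResult. For context, a minimal self-contained
// use of the underlying Random Cut Forest library; the dimension/tree/shingle values are
// illustrative, not the plugin's configuration:
import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

class TrcfScoringSketch {
    public static void main(String[] args) {
        ThresholdedRandomCutForest trcf = ThresholdedRandomCutForest
            .builder()
            .dimensions(8) // shingle size x number of features
            .shingleSize(8)
            .numberOfTrees(30)
            .internalShinglingEnabled(true)
            .build();
        // feed a smooth signal so the forest learns a baseline
        for (int i = 0; i < 200; i++) {
            trcf.process(new double[] { Math.sin(i / 10.0) }, i * 60_000L);
        }
        // an out-of-pattern point should carry a non-zero score and anomaly grade
        AnomalyDescriptor out = trcf.process(new double[] { 10.0 }, 200 * 60_000L);
        System.out.println(out.getRCFScore() + " / " + out.getAnomalyGrade());
    }
}
// ---- end aside ----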
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement"; public static final MockAnomalyDetectorJobAction INSTANCE = new MockAnomalyDetectorJobAction(); private MockAnomalyDetectorJobAction() { - super(NAME, AnomalyDetectorJobResponse::new); + super(NAME, JobResponse::new); } } diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java index bf339161b..3adeead1c 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java @@ -22,10 +22,8 @@ import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -37,12 +35,14 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; -public class MockAnomalyDetectorJobTransportActionWithUser extends - HandledTransportAction<AnomalyDetectorJobRequest, AnomalyDetectorJobResponse> { +public class MockAnomalyDetectorJobTransportActionWithUser extends HandledTransportAction<JobRequest, JobResponse> { private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class); private final Client client; @@ -55,6 +55,7 @@ public class MockAnomalyDetectorJobTransportActionWithUser extends private final ADTaskManager adTaskManager; private final TransportService transportService; private final ExecuteADResultResponseRecorder recorder; + private final NodeStateManager nodeStateManager; @Inject public MockAnomalyDetectorJobTransportActionWithUser( @@ -66,9 +67,10 @@ public MockAnomalyDetectorJobTransportActionWithUser( ADIndexManagement anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, ADTaskManager adTaskManager, - ExecuteADResultResponseRecorder recorder + ExecuteADResultResponseRecorder recorder, + NodeStateManager nodeStateManager ) { - super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); + super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, JobRequest::new); this.transportService = transportService; this.client = client; this.clusterService = clusterService; @@ -82,15 +84,14 @@ public MockAnomalyDetectorJobTransportActionWithUser( ThreadContext threadContext = new ThreadContext(settings); context = threadContext.stashContext(); this.recorder = recorder; + this.nodeStateManager = nodeStateManager; } @Override -
protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener<AnomalyDetectorJobResponse> listener) { - String detectorId = request.getDetectorID(); - DateRange detectionDateRange = request.getDetectionDateRange(); + protected void doExecute(Task task, JobRequest request, ActionListener<JobResponse> listener) { + String detectorId = request.getConfigID(); + DateRange detectionDateRange = request.getDateRange(); boolean historical = request.isHistorical(); - long seqNo = request.getSeqNo(); - long primaryTerm = request.getPrimaryTerm(); String rawPath = request.getRawPath(); TimeValue requestTimeout = AD_REQUEST_TIMEOUT.get(settings); String userStr = "user_name|backendrole1,backendrole2|roles1,role2"; @@ -102,17 +103,7 @@ protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener<AnomalyDetectorJobResponse> listener) { detectorId, filterByEnabled, listener, - (anomalyDetector) -> executeDetector( - listener, - detectorId, - seqNo, - primaryTerm, - rawPath, - requestTimeout, - user, - detectionDateRange, - historical - ), + (anomalyDetector) -> executeDetector(listener, detectorId, rawPath, requestTimeout, user, detectionDateRange, historical), client, clusterService, xContentRegistry, @@ -125,33 +116,28 @@ protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener<AnomalyDetectorJobResponse> listener) { } private void executeDetector( - ActionListener<AnomalyDetectorJobResponse> listener, + ActionListener<JobResponse> listener, String detectorId, - long seqNo, - long primaryTerm, String rawPath, TimeValue requestTimeout, User user, DateRange detectionDateRange, boolean historical ) { - IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + ADIndexJobActionHandler handler = new ADIndexJobActionHandler( client, anomalyDetectionIndices, - detectorId, - seqNo, - primaryTerm, - requestTimeout, xContentRegistry, - transportService, adTaskManager, - recorder + recorder, + nodeStateManager, + Settings.EMPTY ); if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { - adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); + handler.startConfig(detectorId, detectionDateRange, user, transportService, context, listener); } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { // Stop detector - adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener); + handler.stopConfig(detectorId, historical, user, transportService, listener); } } } diff --git a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java index 27456589a..324fc0373 100644 --- a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java @@ -19,6 +19,7 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityTaskProfile; public class ADEntityTaskProfileTests extends OpenSearchSingleNodeTestCase { @@ -32,9 +33,9 @@ protected NamedWriteableRegistry writableRegistry() { return getInstanceFromNode(NamedWriteableRegistry.class); } - private ADEntityTaskProfile createADEntityTaskProfile() { + private EntityTaskProfile createADEntityTaskProfile() { Entity entity = createEntityAndAttributes(); - return new ADEntityTaskProfile(1, 23L, false, 1, 2L, "1234", entity, "4321", ADTaskType.HISTORICAL_HC_ENTITY.name()); + return new EntityTaskProfile(1, 23L, false, 1, 2L, "1234", entity, "4321",
ADTaskType.HISTORICAL_HC_ENTITY.name()); } private Entity createEntityAndAttributes() { @@ -49,24 +50,24 @@ private Entity createEntityAndAttributes() { } public void testADEntityTaskProfileSerialization() throws IOException { - ADEntityTaskProfile entityTask = createADEntityTaskProfile(); + EntityTaskProfile entityTask = createADEntityTaskProfile(); BytesStreamOutput output = new BytesStreamOutput(); entityTask.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - ADEntityTaskProfile parsedEntityTask = new ADEntityTaskProfile(input); + EntityTaskProfile parsedEntityTask = new EntityTaskProfile(input); assertEquals(entityTask, parsedEntityTask); } public void testParseADEntityTaskProfile() throws IOException { - ADEntityTaskProfile entityTask = createADEntityTaskProfile(); + EntityTaskProfile entityTask = createADEntityTaskProfile(); String adEntityTaskProfileString = TestHelpers .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + EntityTaskProfile parsedEntityTask = EntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); assertEquals(entityTask, parsedEntityTask); } public void testParseADEntityTaskProfileWithNullEntity() throws IOException { - ADEntityTaskProfile entityTask = new ADEntityTaskProfile( + EntityTaskProfile entityTask = new EntityTaskProfile( 1, 23L, false, @@ -82,14 +83,14 @@ public void testParseADEntityTaskProfileWithNullEntity() throws IOException { assertNull(entityTask.getEntity()); String adEntityTaskProfileString = TestHelpers .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + EntityTaskProfile parsedEntityTask = EntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); assertEquals(entityTask, parsedEntityTask); } public void testADEntityTaskProfileEqual() { - ADEntityTaskProfile entityTaskOne = createADEntityTaskProfile(); - ADEntityTaskProfile entityTaskTwo = createADEntityTaskProfile(); - ADEntityTaskProfile entityTaskThree = new ADEntityTaskProfile( + EntityTaskProfile entityTaskOne = createADEntityTaskProfile(); + EntityTaskProfile entityTaskTwo = createADEntityTaskProfile(); + EntityTaskProfile entityTaskThree = new EntityTaskProfile( null, null, false, @@ -106,7 +107,7 @@ public void testADEntityTaskProfileEqual() { public void testParseADEntityTaskProfileWithMultipleNullFields() throws IOException { Entity entity = createEntityAndAttributes(); - ADEntityTaskProfile entityTask = new ADEntityTaskProfile( + EntityTaskProfile entityTask = new EntityTaskProfile( null, null, false, @@ -119,7 +120,7 @@ public void testParseADEntityTaskProfileWithMultipleNullFields() throws IOExcept ); String adEntityTaskProfileString = TestHelpers .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + EntityTaskProfile parsedEntityTask = EntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); assertEquals(entityTask, parsedEntityTask); } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java 
b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index d3298eae2..56820dc32 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -11,10 +11,8 @@ package org.opensearch.ad.model; -import static org.opensearch.ad.constant.ADCommonMessages.INVALID_RESULT_INDEX_PREFIX; import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; import static org.opensearch.ad.model.AnomalyDetector.MAX_RESULT_INDEX_NAME_SIZE; -import static org.opensearch.timeseries.constant.CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME; import java.io.IOException; import java.time.Instant; @@ -30,6 +28,8 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -316,7 +316,8 @@ public void testInvalidShingleSize() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -343,7 +344,8 @@ public void testNullDetectorName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -370,7 +372,8 @@ public void testBlankDetectorName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -397,7 +400,8 @@ public void testNullTimeField() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -424,7 +428,8 @@ public void testNullIndices() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -451,7 +456,8 @@ public void testEmptyIndices() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -478,7 +484,8 @@ public void testNullDetectionInterval() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); } @@ -504,7 +511,8 @@ public void testInvalidDetectionInterval() { null, null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); assertEquals("Detection interval must be a positive integer", exception.getMessage()); @@ -531,7 +539,8 @@ public void testInvalidWindowDelay() { null, null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ) ); assertEquals("Interval -1 should be non-negative", exception.getMessage()); @@ -553,7 +562,7 @@ public void testEmptyFeatures() throws IOException { } public void testGetShingleSize() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -571,13 +580,14 @@ public void testGetShingleSize() throws IOException { null, TestHelpers.randomUser(), 
null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); assertEquals((int) anomalyDetector.getShingleSize(), 5); } public void testGetShingleSizeReturnsDefaultValue() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -595,13 +605,14 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); assertEquals((int) anomalyDetector.getShingleSize(), TimeSeriesSettings.DEFAULT_SHINGLE_SIZE); } public void testNullFeatureAttributes() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -619,21 +630,22 @@ public void testNullFeatureAttributes() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); assertNotNull(anomalyDetector.getFeatureAttributes()); assertEquals(0, anomalyDetector.getFeatureAttributes().size()); } public void testValidateResultIndex() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new AnomalyDetector( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), + null, TestHelpers.randomQuery(), TestHelpers.randomIntervalTimeConfiguration(), TestHelpers.randomIntervalTimeConfiguration(), @@ -644,11 +656,11 @@ public void testValidateResultIndex() throws IOException { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); - String errorMessage = anomalyDetector.validateCustomResultIndex("abc"); - assertEquals(INVALID_RESULT_INDEX_PREFIX, errorMessage); + assertEquals(ADCommonMessages.INVALID_RESULT_INDEX_PREFIX, errorMessage); StringBuilder resultIndexNameBuilder = new StringBuilder(CUSTOM_RESULT_INDEX_PREFIX); for (int i = 0; i < MAX_RESULT_INDEX_NAME_SIZE - CUSTOM_RESULT_INDEX_PREFIX.length(); i++) { @@ -658,10 +670,10 @@ public void testValidateResultIndex() throws IOException { resultIndexNameBuilder.append("a"); errorMessage = anomalyDetector.validateCustomResultIndex(resultIndexNameBuilder.toString()); - assertEquals(AnomalyDetector.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); + assertEquals(Config.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); errorMessage = anomalyDetector.validateCustomResultIndex(CUSTOM_RESULT_INDEX_PREFIX + "abc#"); - assertEquals(INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); + assertEquals(CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); } public void testParseAnomalyDetectorWithNoDescription() throws IOException { diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java index 424de19da..28245aa31 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java @@ -11,8 +11,6 @@ package org.opensearch.ad.model; -import static org.opensearch.test.OpenSearchTestCase.randomDouble; - import java.io.IOException; import 
java.util.Collection; import java.util.Locale; diff --git a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java index 9960a5fe2..07c7410b4 100644 --- a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java @@ -21,13 +21,19 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ConfigProfile; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; +import org.opensearch.timeseries.model.ProfileName; public class DetectorProfileTests extends OpenSearchTestCase { - private DetectorProfile createRandomDetectorProfile() { + private ConfigProfile createRandomDetectorProfile() { return new DetectorProfile.Builder() - .state(DetectorState.INIT) + .state(ConfigState.INIT) .error(randomAlphaOfLength(5)) .modelProfile( new ModelProfileOnNode[] { @@ -45,7 +51,7 @@ private DetectorProfile createRandomDetectorProfile() { .totalSizeInBytes(-1) .totalEntities(randomLong()) .activeEntities(randomLong()) - .adTaskProfile( + .taskProfile( new ADTaskProfile( randomAlphaOfLength(5), randomInt(), @@ -60,17 +66,17 @@ private DetectorProfile createRandomDetectorProfile() { } public void testParseDetectorProfile() throws IOException { - DetectorProfile detectorProfile = createRandomDetectorProfile(); + ConfigProfile detectorProfile = createRandomDetectorProfile(); BytesStreamOutput output = new BytesStreamOutput(); detectorProfile.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - DetectorProfile parsedDetectorProfile = new DetectorProfile(input); + ConfigProfile parsedDetectorProfile = new DetectorProfile(input); assertEquals("Detector profile serialization doesn't work", detectorProfile, parsedDetectorProfile); } public void testMergeDetectorProfile() { - DetectorProfile detectorProfileOne = createRandomDetectorProfile(); - DetectorProfile detectorProfileTwo = createRandomDetectorProfile(); + ConfigProfile detectorProfileOne = createRandomDetectorProfile(); + ConfigProfile detectorProfileTwo = createRandomDetectorProfile(); String errorPreMerge = detectorProfileOne.getError(); detectorProfileOne.merge(detectorProfileTwo); assertTrue(detectorProfileOne.toString().contains(detectorProfileTwo.getError())); @@ -79,7 +85,7 @@ public void testMergeDetectorProfile() { } public void testDetectorProfileToXContent() throws IOException { - DetectorProfile detectorProfile = createRandomDetectorProfile(); + ConfigProfile detectorProfile = createRandomDetectorProfile(); String detectorProfileString = TestHelpers.xContentBuilderToString(detectorProfile.toXContent(TestHelpers.builder())); XContentParser parser = TestHelpers.parser(detectorProfileString); Map parsedMap = parser.map(); @@ -89,22 +95,22 @@ public void testDetectorProfileToXContent() throws IOException { } public void testDetectorProfileName() throws IllegalArgumentException { - assertEquals("ad_task", DetectorProfileName.getName(ADCommonName.AD_TASK).getName()); - assertEquals("state", DetectorProfileName.getName(ADCommonName.STATE).getName()); - assertEquals("error", 
DetectorProfileName.getName(ADCommonName.ERROR).getName()); - assertEquals("coordinating_node", DetectorProfileName.getName(ADCommonName.COORDINATING_NODE).getName()); - assertEquals("shingle_size", DetectorProfileName.getName(ADCommonName.SHINGLE_SIZE).getName()); - assertEquals("total_size_in_bytes", DetectorProfileName.getName(ADCommonName.TOTAL_SIZE_IN_BYTES).getName()); - assertEquals("models", DetectorProfileName.getName(ADCommonName.MODELS).getName()); - assertEquals("init_progress", DetectorProfileName.getName(ADCommonName.INIT_PROGRESS).getName()); - assertEquals("total_entities", DetectorProfileName.getName(ADCommonName.TOTAL_ENTITIES).getName()); - assertEquals("active_entities", DetectorProfileName.getName(ADCommonName.ACTIVE_ENTITIES).getName()); - IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> DetectorProfileName.getName("abc")); + assertEquals("ad_task", ProfileName.getName(ADCommonName.AD_TASK).getName()); + assertEquals("state", ProfileName.getName(CommonName.STATE).getName()); + assertEquals("error", ProfileName.getName(CommonName.ERROR).getName()); + assertEquals("coordinating_node", ProfileName.getName(CommonName.COORDINATING_NODE).getName()); + assertEquals("shingle_size", ProfileName.getName(CommonName.SHINGLE_SIZE).getName()); + assertEquals("total_size_in_bytes", ProfileName.getName(CommonName.TOTAL_SIZE_IN_BYTES).getName()); + assertEquals("models", ProfileName.getName(CommonName.MODELS).getName()); + assertEquals("init_progress", ProfileName.getName(CommonName.INIT_PROGRESS).getName()); + assertEquals("total_entities", ProfileName.getName(CommonName.TOTAL_ENTITIES).getName()); + assertEquals("active_entities", ProfileName.getName(CommonName.ACTIVE_ENTITIES).getName()); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> ProfileName.getName("abc")); assertEquals(exception.getMessage(), ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); } public void testDetectorProfileSet() throws IllegalArgumentException { - DetectorProfile detectorProfileOne = createRandomDetectorProfile(); + ConfigProfile detectorProfileOne = createRandomDetectorProfile(); detectorProfileOne.setShingleSize(20); assertEquals(20, detectorProfileOne.getShingleSize()); detectorProfileOne.setActiveEntities(10L); diff --git a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java index 24cb0c879..18be64d54 100644 --- a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java @@ -18,8 +18,8 @@ import java.util.List; import org.junit.Test; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.StatsResponse; public class EntityAnomalyResultTests extends OpenSearchTestCase { @@ -90,7 +90,7 @@ public void testMerge_self() { @Test public void testMerge_otherClass() { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { diff --git a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java index 18e179145..f647da3ac 100644 --- a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java +++ 
b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java @@ -16,10 +16,12 @@ import java.io.IOException; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.EntityProfile; +import org.opensearch.timeseries.model.EntityState; import test.org.opensearch.ad.util.JsonDeserializer; @@ -39,7 +41,7 @@ public void testToXContent() throws IOException, JsonPathNotFoundException { profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = builder.toString(); - assertEquals("INIT", JsonDeserializer.getTextValue(json, ADCommonName.STATE)); + assertEquals("INIT", JsonDeserializer.getTextValue(json, CommonName.STATE)); EntityProfile profile2 = new EntityProfile(null, -1, -1, null, null, EntityState.UNKNOWN); @@ -47,7 +49,7 @@ public void testToXContent() throws IOException, JsonPathNotFoundException { profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); json = builder.toString(); - assertTrue(false == JsonDeserializer.hasChildNode(json, ADCommonName.STATE)); + assertTrue(false == JsonDeserializer.hasChildNode(json, CommonName.STATE)); } public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFoundException { @@ -57,7 +59,7 @@ public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFo profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = builder.toString(); - assertEquals("INIT", JsonDeserializer.getTextValue(json, ADCommonName.STATE)); + assertEquals("INIT", JsonDeserializer.getTextValue(json, CommonName.STATE)); EntityProfile profile2 = new EntityProfile(null, 1, 1, null, null, EntityState.UNKNOWN); @@ -65,6 +67,6 @@ public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFo profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); json = builder.toString(); - assertTrue(false == JsonDeserializer.hasChildNode(json, ADCommonName.STATE)); + assertTrue(false == JsonDeserializer.hasChildNode(json, CommonName.STATE)); } } diff --git a/src/test/java/org/opensearch/ad/model/ModelProfileTests.java b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java index c99ff6222..b5c9c852a 100644 --- a/src/test/java/org/opensearch/ad/model/ModelProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java @@ -20,6 +20,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; import test.org.opensearch.ad.util.JsonDeserializer; diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java index 830ac3f65..eb9d51da7 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java @@ -27,40 +27,45 @@ import java.util.Optional; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import 
org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckPointMaintainRequestAdapterTests extends AbstractRateLimitingTest { - private CacheProvider cache; - private CheckpointDao checkpointDao; + private ADCacheProvider cache; + private ADCheckpointDao checkpointDao; private String indexName; private Setting<TimeValue> checkpointInterval; private CheckPointMaintainRequestAdapter adapter; - private ModelState<EntityModel> state; + private ModelState<ThresholdedRandomCutForest> state; private CheckpointMaintainRequest request; private ClusterService clusterService; @Override public void setUp() throws Exception { super.setUp(); - cache = mock(CacheProvider.class); - checkpointDao = mock(CheckpointDao.class); + cache = mock(ADCacheProvider.class); + checkpointDao = mock(ADCheckpointDao.class); indexName = ADCommonName.CHECKPOINT_INDEX_NAME; checkpointInterval = AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ; - EntityCache entityCache = mock(EntityCache.class); + ADPriorityCache entityCache = mock(ADPriorityCache.class); when(cache.get()).thenReturn(entityCache); state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); @@ -71,13 +76,13 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(settings); adapter = new CheckPointMaintainRequestAdapter( - cache, checkpointDao, indexName, checkpointInterval, clock, clusterService, - Settings.EMPTY + Settings.EMPTY, + cache ); request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity.getModelId(detectorId).get()); diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java index 0d05259fc..04913cc6c 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java @@ -32,12 +32,12 @@ import java.util.Optional; import java.util.Random; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement;
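// ---- reviewer aside (illustration only, not part of the patch; the import hunk continues below) ----
// In the CheckpointMaintainWorkerTests that follow, the worker no longer receives the adapter object;
// it receives the method reference adapter::convert. All the worker depends on is a function from a
// maintain request to an optional checkpoint-write request: the adapter looks the model state up in
// the cache and yields nothing when the entry is gone or was checkpointed recently enough. A sketch
// of that functional shape with stand-in types (MaintainReq/WriteReq are hypothetical):
import java.util.Optional;
import java.util.function.Function;

class ConvertShapeSketch {
    static final class MaintainReq { final String modelId; MaintainReq(String id) { this.modelId = id; } }
    static final class WriteReq { final String modelId; WriteReq(String id) { this.modelId = id; } }

    // the worker only needs this shape, so a test can pass any lambda in place of the real adapter
    static void drive(Function<MaintainReq, Optional<WriteReq>> convert) {
        convert.apply(new MaintainReq("detector-1_model_x")).ifPresent(w -> System.out.println("enqueue " + w.modelId));
    }

    public static void main(String[] args) {
        drive(req -> Optional.of(new WriteReq(req.modelId))); // cache hit: forward to the write queue
        drive(req -> Optional.empty()); // cache miss: request is dropped silently
    }
}
// ---- end aside ----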
+import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -45,19 +45,25 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckpointMaintainWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; - CheckpointMaintainWorker cpMaintainWorker; - CheckpointWriteWorker writeWorker; + ADCheckpointMaintainWorker cpMaintainWorker; + ADCheckpointWriteWorker writeWorker; CheckpointMaintainRequest request; CheckpointMaintainRequest request2; List<CheckpointMaintainRequest> requests; - CheckpointDao checkpointDao; + ADCheckpointDao checkpointDao; @Override public void setUp() throws Exception { @@ -81,30 +87,34 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - writeWorker = mock(CheckpointWriteWorker.class); + writeWorker = mock(ADCheckpointWriteWorker.class); + + ADCacheProvider adCacheProvider = new ADCacheProvider(); - CacheProvider cache = mock(CacheProvider.class); - checkpointDao = mock(CheckpointDao.class); + ADPriorityCache cache = mock(ADPriorityCache.class); + checkpointDao = mock(ADCheckpointDao.class); String indexName = ADCommonName.CHECKPOINT_INDEX_NAME; Setting<TimeValue> checkpointInterval = AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ; - EntityCache entityCache = mock(EntityCache.class); - when(cache.get()).thenReturn(entityCache); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); - CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( - cache, - checkpointDao, - indexName, - checkpointInterval, - clock, - clusterService, - settings - ); + + ModelState<ThresholdedRandomCutForest> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(cache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); + adCacheProvider.set(cache); + CheckPointMaintainRequestAdapter adapter = + new CheckPointMaintainRequestAdapter<>( + checkpointDao, + indexName, + checkpointInterval, + clock, + clusterService, + settings, + adCacheProvider + ); // Integer.MAX_VALUE makes a huge heap - cpMaintainWorker = new CheckpointMaintainWorker( + cpMaintainWorker = new ADCheckpointMaintainWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -119,7 +129,7 @@ public void setUp() throws Exception { writeWorker, TimeSeriesSettings.HOURLY_MAINTENANCE, nodeStateManager, - adapter + adapter::convert ); request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId,
RequestPriority.LOW, entity.getModelId(detectorId).get()); diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index 41b8035b0..ff9690ba9 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -11,10 +11,8 @@ package org.opensearch.ad.ratelimit; -import static java.util.AbstractMap.SimpleImmutableEntry; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -45,21 +43,18 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -75,31 +70,37 @@ import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.fasterxml.jackson.core.JsonParseException; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckpointReadWorkerTests extends AbstractRateLimitingTest { - CheckpointReadWorker worker; + ADCheckpointReadWorker worker; - CheckpointDao checkpoint; + ADCheckpointDao checkpoint; ClusterService clusterService; - ModelState<EntityModel> state; + ModelState<ThresholdedRandomCutForest> state; - CheckpointWriteWorker checkpointWriteQueue; - ModelManager modelManager; - EntityColdStartWorker coldstartQueue; - ResultWriteWorker resultWriteQueue; + ADCheckpointWriteWorker checkpointWriteQueue; + ADModelManager modelManager; + ADColdStartWorker coldstartQueue; + ADSaveResultStrategy resultWriteStrategy;
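// ---- reviewer aside (illustration only, not part of the patch) ----
// The rewritten CheckpointReadWorkerTests below wire the new read pipeline: a FeatureRequest is
// dequeued, ADCheckpointDao.processHCGetResponse restores a ModelState<ThresholdedRandomCutForest>
// in one step (replacing the old processGetResponse + processEntityCheckpoint pair), the state is
// hosted in ADPriorityCache, the point is scored, and ADSaveResultStrategy persists the result.
// A condensed sketch of that control flow; the interfaces are stand-ins for the mocked classes,
// not their real signatures:
import java.util.Optional;

class CheckpointReadFlowSketch {
    static final class State { final String modelId; State(String id) { this.modelId = id; } }

    interface Checkpoint { Optional<State> restore(String modelId); } // stands in for processHCGetResponse
    interface Cache { boolean hostIfPossible(State state); }          // stands in for ADPriorityCache
    interface Scorer { double grade(double[] point, State state); }   // stands in for ADModelManager.getResult
    interface ResultSink { void save(String modelId, double grade); } // stands in for ADSaveResultStrategy

    static void process(String modelId, double[] point, Checkpoint cp, Cache cache, Scorer scorer, ResultSink sink) {
        Optional<State> restored = cp.restore(modelId);
        if (restored.isEmpty()) {
            return; // no checkpoint: the real worker forwards the request to the cold-start queue instead
        }
        State state = restored.get();
        cache.hostIfPossible(state); // best effort; this sketch scores either way
        sink.save(modelId, scorer.grade(point, state));
    }
}
// ---- end aside ----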
ADSaveResultStrategy resultWriteStrategy;
     ADIndexManagement anomalyDetectionIndices;

-    CacheProvider cacheProvider;
-    EntityCache entityCache;
-    EntityFeatureRequest request, request2, request3;
+    Provider<ADPriorityCache> cacheProvider;
+    ADPriorityCache entityCache;
+    FeatureRequest request, request2, request3;
     ClusterSettings clusterSettings;
     ADStats adStats;

@@ -125,38 +126,36 @@ public void setUp() throws Exception {

         state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());

-        checkpoint = mock(CheckpointDao.class);
+        checkpoint = mock(ADCheckpointDao.class);

-        Map.Entry<EntityModel, Instant> entry = new SimpleImmutableEntry(state.getModel(), Instant.now());
-        when(checkpoint.processGetResponse(any(), anyString())).thenReturn(Optional.of(entry));
+        when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state);

-        checkpointWriteQueue = mock(CheckpointWriteWorker.class);
+        checkpointWriteQueue = mock(ADCheckpointWriteWorker.class);

-        modelManager = mock(ModelManager.class);
-        when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state);
-        when(modelManager.score(any(), anyString(), any())).thenReturn(new ThresholdingResult(0, 1, 0.7));
+        modelManager = mock(ADModelManager.class);
+        when(modelManager.getResult(any(), any(), anyString(), any(), any(), anyString())).thenReturn(new ThresholdingResult(0, 1, 0.7));

-        coldstartQueue = mock(EntityColdStartWorker.class);
-        resultWriteQueue = mock(ResultWriteWorker.class);
+        coldstartQueue = mock(ADColdStartWorker.class);
+        resultWriteStrategy = mock(ADSaveResultStrategy.class);

         anomalyDetectionIndices = mock(ADIndexManagement.class);

-        cacheProvider = mock(CacheProvider.class);
-        entityCache = mock(EntityCache.class);
+        cacheProvider = new ADCacheProvider();
+        entityCache = mock(ADPriorityCache.class);
         when(cacheProvider.get()).thenReturn(entityCache);
         when(entityCache.hostIfPossible(any(), any())).thenReturn(true);

-        Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() {
-            {
-                put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
+        Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() {
+            {
+                put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
             }
         };

         adStats = new ADStats(statsMap);

         // Integer.MAX_VALUE makes a huge heap
-        worker = new CheckpointReadWorker(
+        worker = new ADCheckpointReadWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES,
+            TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES,
             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
@@ -172,18 +171,18 @@
             modelManager,
             checkpoint,
             coldstartQueue,
-            resultWriteQueue,
             nodeStateManager,
             anomalyDetectionIndices,
             cacheProvider,
             TimeSeriesSettings.HOURLY_MAINTENANCE,
             checkpointWriteQueue,
-            adStats
+            adStats,
+            resultWriteStrategy
         );

-        request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity, new double[] { 0 }, 0);
-        request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0);
-        request3 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity3, new double[] { 0 }, 0);
+        request = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity, null);
+        request2 = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity2, null);
+        request3 = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity3, null);
     }
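Note for reviewers: the EntityFeatureRequest-to-FeatureRequest swap above also reorders the constructor arguments — the entity now comes after the feature array and data start time, and a new trailing argument (always null in these tests) is appended. A minimal sketch of the new shape, with illustrative field names that are guesses rather than the plugin's real ones (in particular, reading the trailing null as a task id is an assumption):

```java
import java.util.Arrays;

// Hedged sketch of the reordered constructor; these names are illustrative,
// not the plugin's actual FeatureRequest fields.
final class FeatureRequestSketch {
    final long expirationEpochMs;
    final String configId;
    final String priority;
    final double[] currentFeature;   // moved ahead of the entity
    final long dataStartTimeMs;      // moved ahead of the entity
    final Object entity;             // nullable for single-stream detectors
    final String taskId;             // assumed meaning of the trailing null

    FeatureRequestSketch(long expirationEpochMs, String configId, String priority,
                         double[] currentFeature, long dataStartTimeMs, Object entity, String taskId) {
        this.expirationEpochMs = expirationEpochMs;
        this.configId = configId;
        this.priority = priority;
        this.currentFeature = currentFeature;
        this.dataStartTimeMs = dataStartTimeMs;
        this.entity = entity;
        this.taskId = taskId;
    }

    public static void main(String[] args) {
        // old order, for contrast: ..., entity, new double[] { 0 }, 0
        // new order, as used in setUp() above:
        FeatureRequestSketch request =
            new FeatureRequestSketch(Long.MAX_VALUE, "123", "MEDIUM", new double[] { 0 }, 0L, "entity", null);
        System.out.println(request.configId + " " + Arrays.toString(request.currentFeature));
    }
}
```

Because the feature array and data start time keep their relative order while the entity slides past them, call sites that compiled against the old signature would fail loudly rather than silently reorder — which is why every test constructor call in this file changed.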

     static class RegularSetUpConfig {
@@ -232,16 +231,15 @@ private void regularTestSetUp(RegularSetUpConfig config) {
         when(entityCache.hostIfPossible(any(), any())).thenReturn(config.canHostModel);

         state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(config.fullModel).build());
-        when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state);
+        when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state);
+
         if (config.fullModel) {
-            when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()))
-                .thenReturn(new ThresholdingResult(0, 1, 1));
+            when(modelManager.getResult(any(), any(), anyString(), any(), any(), anyString())).thenReturn(new ThresholdingResult(0, 1, 1));
         } else {
-            when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()))
-                .thenReturn(new ThresholdingResult(0, 0, 0));
+            when(modelManager.getResult(any(), any(), anyString(), any(), any(), anyString())).thenReturn(new ThresholdingResult(0, 0, 0));
         }

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         worker.putAll(requests);
     }
@@ -249,20 +247,20 @@ private void regularTestSetUp(RegularSetUpConfig config) {

     public void testRegular() {
         regularTestSetUp(new RegularSetUpConfig.Builder().build());

-        verify(resultWriteQueue, times(1)).put(any());
+        verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString());
         verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any());
     }

     public void testCannotLoadModel() {
         regularTestSetUp(new RegularSetUpConfig.Builder().canHostModel(false).build());

-        verify(resultWriteQueue, times(1)).put(any());
+        verify(resultWriteStrategy, times(1)).saveResult(any(), any(), any(), anyString());
         verify(checkpointWriteQueue, times(1)).write(any(), anyBoolean(), any());
     }

     public void testNoFullModel() {
         regularTestSetUp(new RegularSetUpConfig.Builder().fullModel(false).build());

-        verify(resultWriteQueue, never()).put(any());
+        verify(resultWriteStrategy, never()).saveResult(any(), any(), any(), anyString());
         verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any());
     }

@@ -327,7 +325,7 @@ public void testAllDocNotFound() {
             return null;
         }).when(checkpoint).batchRead(any(), any());

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         requests.add(request2);
         worker.putAll(requests);
@@ -366,7 +364,7 @@ public void testSingleDocNotFound() {
             return null;
         }).when(checkpoint).batchRead(any(), any());

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         requests.add(request2);
         worker.putAll(requests);
@@ -436,7 +434,7 @@ public void testTimeout() {
             return null;
         }).when(checkpoint).batchRead(any(), any());

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         requests.add(request2);
         worker.putAll(requests);
@@ -533,9 +531,9 @@ public void testRemoveUnusedQueues() {
         ExecutorService executorService = mock(ExecutorService.class);
         when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService);

-        worker = new CheckpointReadWorker(
+        worker = new ADCheckpointReadWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES,
+            1,
             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
@@ -551,19 +549,19 @@ public void testRemoveUnusedQueues() {
             modelManager,
             checkpoint,
             coldstartQueue,
-            resultWriteQueue,
             nodeStateManager,
             anomalyDetectionIndices,
             cacheProvider,
             TimeSeriesSettings.HOURLY_MAINTENANCE,
             checkpointWriteQueue,
-            adStats
+            adStats,
+            resultWriteStrategy
         );

         regularTestSetUp(new RegularSetUpConfig.Builder().build());
         assertTrue(!worker.isQueueEmpty());
-        assertEquals(CheckpointReadWorker.WORKER_NAME, worker.getWorkerName());
+        assertEquals(ADCheckpointReadWorker.WORKER_NAME, worker.getWorkerName());

         // make RequestQueue.expired return true
         when(clock.instant()).thenReturn(Instant.now().plusSeconds(TimeSeriesSettings.HOURLY_MAINTENANCE.getSeconds() + 1));
@@ -585,7 +583,7 @@ public void testSettingUpdatable() {
         maintenanceSetup();

         // can host two requests in the queue
-        worker = new CheckpointReadWorker(
+        worker = new ADCheckpointReadWorker(
             2000,
             1,
             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT,
@@ -603,16 +601,16 @@
             modelManager,
             checkpoint,
             coldstartQueue,
-            resultWriteQueue,
             nodeStateManager,
             anomalyDetectionIndices,
             cacheProvider,
             TimeSeriesSettings.HOURLY_MAINTENANCE,
             checkpointWriteQueue,
-            adStats
+            adStats,
+            resultWriteStrategy
         );

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         requests.add(request2);
         worker.putAll(requests);
@@ -638,9 +636,9 @@ public void testOpenCircuitBreaker() {
         CircuitBreakerService breaker = mock(CircuitBreakerService.class);
         when(breaker.isOpen()).thenReturn(true);

-        worker = new CheckpointReadWorker(
+        worker = new ADCheckpointReadWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES,
+            1,
             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
@@ -656,16 +654,16 @@
             modelManager,
             checkpoint,
             coldstartQueue,
-            resultWriteQueue,
             nodeStateManager,
             anomalyDetectionIndices,
             cacheProvider,
             TimeSeriesSettings.HOURLY_MAINTENANCE,
             checkpointWriteQueue,
-            adStats
+            adStats,
+            resultWriteStrategy
         );

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         requests.add(request2);
         worker.putAll(requests);
@@ -713,23 +711,24 @@ public void testChangePriority() {
     }

     public void testDetectorId() {
-        assertEquals(detectorId, request.getId());
+        assertEquals(detectorId, request.getConfigId());
         String newDetectorId = "456";
         request.setDetectorId(newDetectorId);
-        assertEquals(newDetectorId, request.getId());
+        assertEquals(newDetectorId, request.getConfigId());
     }

     @SuppressWarnings("unchecked")
     public void testHostException() throws IOException {
         String detectorId2 = "456";
         Entity entity4 = Entity.createSingleAttributeEntity(categoryField, "value4");
-        EntityFeatureRequest request4 = new EntityFeatureRequest(
+        FeatureRequest request4 = new FeatureRequest(
             Integer.MAX_VALUE,
             detectorId2,
             RequestPriority.MEDIUM,
-            entity4,
             new double[] { 0 },
-            0
+            0,
+            entity4,
+            null
         );

         AnomalyDetector detector2 = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId2, Arrays.asList(categoryField));
@@ -777,7 +776,7 @@ public void testHostException() throws IOException {

         doThrow(LimitExceededException.class).when(entityCache).hostIfPossible(eq(detector2), any());

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         requests.add(request4);
         worker.putAll(requests);
@@ -803,17 +802,17 @@ public void testFailToScore() {
         }).when(checkpoint).batchRead(any(), any());

         state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());
-        when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state);
-        doThrow(new IllegalArgumentException()).when(modelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt());
+        when(checkpoint.processHCGetResponse(any(), anyString(), anyString())).thenReturn(state);
+        doThrow(new IllegalArgumentException()).when(modelManager).getResult(any(), any(), anyString(), any(), any(), anyString());

-        List<EntityFeatureRequest> requests = new ArrayList<>();
+        List<FeatureRequest> requests = new ArrayList<>();
         requests.add(request);
         worker.putAll(requests);

-        verify(resultWriteQueue, never()).put(any());
+        verify(resultWriteStrategy, never()).saveResult(any(), any(), any(), anyString());
         verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any());
         verify(coldstartQueue, times(1)).put(any());

-        Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue();
+        Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue();
         assertEquals(1L, ((Long) val).longValue());
     }
 }
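testFailToScore above pins down the corruption path: when scoring throws, no result is written, the request is re-queued for cold start, and the model-corruption counter is incremented. A minimal sketch of that control flow in plain Java (LongAdder stands in for the counter supplier; none of these names are the plugin's API):

```java
import java.util.Optional;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;

// Illustrative sketch of the corruption-handling path the test verifies:
// a throwing model yields no result, bumps the corruption counter, and
// sends the request back for cold start.
public final class ScoringPipelineSketch {
    private final LongAdder modelCorruptionCount = new LongAdder();

    Optional<Double> score(double[] point, Function<double[], Double> model, Runnable requeueForColdStart) {
        try {
            return Optional.of(model.apply(point));
        } catch (IllegalArgumentException e) {
            modelCorruptionCount.increment(); // what the corruption stat tracks
            requeueForColdStart.run();        // counterpart of coldstartQueue.put(request)
            return Optional.empty();          // nothing reaches the result index
        }
    }

    long corruptionCount() {
        return modelCorruptionCount.sum();
    }

    public static void main(String[] args) {
        ScoringPipelineSketch pipeline = new ScoringPipelineSketch();
        Optional<Double> result = pipeline.score(
            new double[] { 0 },
            p -> { throw new IllegalArgumentException("corrupt checkpoint"); },
            () -> System.out.println("re-queued for cold start")
        );
        System.out.println(result.isPresent() + " / corrupted=" + pipeline.corruptionCount()); // false / corrupted=1
    }
}
```

(As an aside, `MODEL_CORRUTPION_COUNT` is spelled that way in the codebase; the rename to `AD_MODEL_CORRUTPION_COUNT` deliberately keeps the existing identifier.)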
diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java
index be83484ee..425b5e973 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java
@@ -46,9 +46,7 @@
 import org.opensearch.action.bulk.BulkResponse;
 import org.opensearch.action.index.IndexResponse;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.ml.CheckpointDao;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelState;
+import org.opensearch.ad.ml.ADCheckpointDao;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.service.ClusterService;
@@ -65,18 +63,22 @@
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
 import org.opensearch.timeseries.breaker.CircuitBreakerService;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.ml.ModelState;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
 import org.opensearch.timeseries.settings.TimeSeriesSettings;

+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
+
 import test.org.opensearch.ad.util.MLUtil;
 import test.org.opensearch.ad.util.RandomModelStateConfig;

 public class CheckpointWriteWorkerTests extends AbstractRateLimitingTest {
-    CheckpointWriteWorker worker;
+    ADCheckpointWriteWorker worker;

-    CheckpointDao checkpoint;
+    ADCheckpointDao checkpoint;
     ClusterService clusterService;
-    ModelState<EntityModel> state;
+    ModelState<ThresholdedRandomCutForest> state;

     @Override
     @SuppressWarnings("unchecked")
@@ -99,14 +101,14 @@ public void setUp() throws Exception {
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

-        checkpoint = mock(CheckpointDao.class);
+        checkpoint = mock(ADCheckpointDao.class);
         Map<String, Object> checkpointMap = new HashMap<>();
         checkpointMap.put(CommonName.FIELD_MODEL, "a");
         when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap);
         when(checkpoint.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true);

         // Integer.MAX_VALUE makes a huge heap
-        worker = new CheckpointWriteWorker(
+        worker = new ADCheckpointWriteWorker(
             Integer.MAX_VALUE,
             TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES,
             AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
@@ -166,7 +168,7 @@ public void testTriggerSaveAll() {
             return null;
         }).when(checkpoint).batchWrite(any(), any());

-        List<ModelState<EntityModel>> states = new ArrayList<>();
+        List<ModelState<ThresholdedRandomCutForest>> states = new ArrayList<>();
         states.add(state);
         worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM);

@@ -210,7 +212,7 @@ public void testTriggerAutoFlush() throws InterruptedException {

         // Integer.MAX_VALUE makes a huge heap
         // create a worker to use mockThreadPool
-        worker = new CheckpointWriteWorker(
+        worker = new ADCheckpointWriteWorker(
             Integer.MAX_VALUE,
             TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES,
             AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
@@ -239,7 +241,7 @@ public void testTriggerAutoFlush() throws InterruptedException {
         // CHECKPOINT_WRITE_QUEUE_BATCH_SIZE is the largest batch size
         int numberOfRequests = 2 * AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.getDefault(Settings.EMPTY) + 1;
         for (int i = 0; i < numberOfRequests; i++) {
-            ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build());
+            ModelState<ThresholdedRandomCutForest> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build());
             worker.write(state, true, RequestPriority.MEDIUM);
         }

@@ -268,7 +270,7 @@ public void testOverloaded() {
         worker.write(state, true, RequestPriority.MEDIUM);

         verify(checkpoint, times(1)).batchWrite(any(), any());
-        verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchRejectedExecutionException.class));
+        verify(nodeStateManager, times(1)).setException(eq(state.getConfigId()), any(OpenSearchRejectedExecutionException.class));
     }

     public void testRetryException() {
@@ -282,7 +284,7 @@ public void testRetryException() {
         worker.write(state, true, RequestPriority.MEDIUM);
         // we don't retry checkpoint write
         verify(checkpoint, times(1)).batchWrite(any(), any());
-        verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchStatusException.class));
+        verify(nodeStateManager, times(1)).setException(eq(state.getConfigId()), any(OpenSearchStatusException.class));
     }

     /**
@@ -310,7 +312,7 @@ public void testFailedRequest() {

     @SuppressWarnings("unchecked")
     public void testEmptyTimeStamp() {
-        ModelState<EntityModel> state = mock(ModelState.class);
+        ModelState<ThresholdedRandomCutForest> state = mock(ModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.MIN);
         worker.write(state, false, RequestPriority.MEDIUM);

@@ -319,7 +321,7 @@ public void testEmptyTimeStamp() {

     @SuppressWarnings("unchecked")
     public void testTooSoonToSaveSingleWrite() {
-        ModelState<EntityModel> state = mock(ModelState.class);
+        ModelState<ThresholdedRandomCutForest> state = mock(ModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
         worker.write(state, false, RequestPriority.MEDIUM);

@@ -328,10 +330,10 @@ public void testTooSoonToSaveSingleWrite() {

     @SuppressWarnings("unchecked")
     public void testTooSoonToSaveWriteAll() {
-        ModelState<EntityModel> state = mock(ModelState.class);
+        ModelState<ThresholdedRandomCutForest> state = mock(ModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());

-        List<ModelState<EntityModel>> states = new ArrayList<>();
+        List<ModelState<ThresholdedRandomCutForest>> states = new ArrayList<>();
         states.add(state);

         worker.writeAll(states, detectorId, false, RequestPriority.MEDIUM);

@@ -341,7 +343,7 @@ public void testTooSoonToSaveWriteAll() {

     @SuppressWarnings("unchecked")
     public void testEmptyModel() {
-        ModelState<EntityModel> state = mock(ModelState.class);
+        ModelState<ThresholdedRandomCutForest> state = mock(ModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
         when(state.getModel()).thenReturn(null);
         worker.write(state, true, RequestPriority.MEDIUM);

@@ -351,11 +353,11 @@ public void testEmptyModel() {

     @SuppressWarnings("unchecked")
     public void testEmptyModelId() {
-        ModelState<EntityModel> state = mock(ModelState.class);
+        ModelState<ThresholdedRandomCutForest> state = mock(ModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
-        EntityModel model = mock(EntityModel.class);
-        when(state.getModel()).thenReturn(model);
-        when(state.getId()).thenReturn("1");
+        ThresholdedRandomCutForest model = mock(ThresholdedRandomCutForest.class);
+        when(state.getModel()).thenReturn(Optional.of(model));
+        when(state.getConfigId()).thenReturn("1");
         when(state.getModelId()).thenReturn(null);
         worker.write(state, true, RequestPriority.MEDIUM);

@@ -364,11 +366,11 @@ public void testEmptyModelId() {

     @SuppressWarnings("unchecked")
     public void testEmptyDetectorId() {
-        ModelState<EntityModel> state = mock(ModelState.class);
+        ModelState<ThresholdedRandomCutForest> state = mock(ModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
-        EntityModel model = mock(EntityModel.class);
-        when(state.getModel()).thenReturn(model);
-        when(state.getId()).thenReturn(null);
+        ThresholdedRandomCutForest model = mock(ThresholdedRandomCutForest.class);
+        when(state.getModel()).thenReturn(Optional.of(model));
+        when(state.getConfigId()).thenReturn(null);
         when(state.getModelId()).thenReturn("a");
         worker.write(state, true, RequestPriority.MEDIUM);

@@ -395,7 +397,7 @@ public void testDetectorNotAvailableWriteAll() {
             return null;
         }).when(nodeStateManager).getConfig(any(String.class), eq(AnalysisType.AD), any(ActionListener.class));

-        List<ModelState<EntityModel>> states = new ArrayList<>();
+        List<ModelState<ThresholdedRandomCutForest>> states = new ArrayList<>();
         states.add(state);
         worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM);
         verify(checkpoint, never()).batchWrite(any(), any());
diff --git a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java
index d093f20ae..96d176a7f 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java
@@ -33,14 +33,16 @@
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.ratelimit.FeatureRequest;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
 import org.opensearch.timeseries.settings.TimeSeriesSettings;

 public class ColdEntityWorkerTests extends AbstractRateLimitingTest {
     ClusterService clusterService;
-    ColdEntityWorker coldWorker;
-    CheckpointReadWorker readWorker;
-    EntityFeatureRequest request, request2, invalidRequest;
-    List<EntityFeatureRequest> requests;
+    ADColdEntityWorker coldWorker;
+    ADCheckpointReadWorker readWorker;
+    FeatureRequest request, request2, invalidRequest;
+    List<FeatureRequest> requests;

     @Override
     public void setUp() throws Exception {
@@ -63,12 +65,12 @@ public void setUp() throws Exception {
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

-        readWorker = mock(CheckpointReadWorker.class);
+        readWorker = mock(ADCheckpointReadWorker.class);

         // Integer.MAX_VALUE makes a huge heap
-        coldWorker = new ColdEntityWorker(
+        coldWorker = new ADColdEntityWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES,
+            TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES,
AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -85,9 +87,9 @@ public void setUp() throws Exception { nodeStateManager ); - request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity, new double[] { 0 }, 0); - request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2, new double[] { 0 }, 0); - invalidRequest = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); + request = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, new double[] { 0 }, 0, entity, null); + request2 = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, new double[] { 0 }, 0, entity2, null); + invalidRequest = new FeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, new double[] { 0 }, 0, entity2, null); requests = new ArrayList<>(); requests.add(request); @@ -154,9 +156,9 @@ public void testDelay() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); // Integer.MAX_VALUE makes a huge heap - coldWorker = new ColdEntityWorker( + coldWorker = new ADColdEntityWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java index 9fdf5a396..8bf52ecf1 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java @@ -24,14 +24,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; -import java.util.Optional; import java.util.Random; import org.opensearch.OpenSearchStatusException; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -40,15 +38,20 @@ import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; import org.opensearch.core.rest.RestStatus; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + import test.org.opensearch.ad.util.MLUtil; public class EntityColdStartWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; - EntityColdStartWorker worker; - EntityColdStarter entityColdStarter; - CacheProvider cacheProvider; + ADColdStartWorker worker; + ADEntityColdStart entityColdStarter; + ADPriorityCache cacheProvider; @Override public void setUp() throws Exception { @@ -69,14 +72,14 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - 
entityColdStarter = mock(EntityColdStarter.class); + entityColdStarter = mock(ADEntityColdStart.class); - cacheProvider = mock(CacheProvider.class); + cacheProvider = mock(ADPriorityCache.class); // Integer.MAX_VALUE makes a huge heap - worker = new EntityColdStartWorker( + worker = new ADColdStartWorker( Integer.MAX_VALUE, - AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), @@ -92,21 +95,31 @@ public void setUp() throws Exception { entityColdStarter, TimeSeriesSettings.HOURLY_MAINTENANCE, nodeStateManager, - cacheProvider + cacheProvider, + mock(ADModelManager.class), + mock(ADSaveResultStrategy.class) ); } public void testEmptyModelId() { - EntityRequest request = mock(EntityRequest.class); + FeatureRequest request = mock(FeatureRequest.class); when(request.getPriority()).thenReturn(RequestPriority.LOW); - when(request.getModelId()).thenReturn(Optional.empty()); + when(request.getModelId()).thenReturn(null); worker.put(request); verify(entityColdStarter, never()).trainModel(any(), anyString(), any(), any()); verify(request, times(1)).getModelId(); } public void testOverloaded() { - EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + FeatureRequest request = new FeatureRequest( + Integer.MAX_VALUE, + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + 0, + entity, + null + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -126,7 +139,15 @@ public void testOverloaded() { } public void testException() { - EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + FeatureRequest request = new FeatureRequest( + Integer.MAX_VALUE, + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + 0, + entity, + null + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -147,13 +168,21 @@ public void testException() { } public void testModelHosted() { - EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + FeatureRequest request = new FeatureRequest( + Integer.MAX_VALUE, + detectorId, + RequestPriority.MEDIUM, + new double[] { 0 }, + 0, + entity, + null + ); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); - ModelState state = invocation.getArgument(2); - state.setModel(MLUtil.createNonEmptyModel(detectorId)); + ModelState state = invocation.getArgument(2); + state.setModel(MLUtil.createNonEmptyModel(detectorId).getLeft()); listener.onResponse(null); return null; @@ -161,6 +190,6 @@ public void testModelHosted() { worker.put(request); - verify(cacheProvider, times(1)).get(); + verify(cacheProvider, times(1)).get(anyString(), any()); } } diff --git a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java index 304a942c7..2d829722b 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java @@ -37,8 +37,7 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import 
org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -49,13 +48,15 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.ResultBulkResponse; import org.opensearch.timeseries.util.RestHandlerUtils; public class ResultWriteWorkerTests extends AbstractRateLimitingTest { - ResultWriteWorker resultWriteQueue; + ADResultWriteWorker resultWriteQueue; ClusterService clusterService; - MultiEntityResultHandler resultHandler; + ADIndexMemoryPressureAwareResultHandler resultHandler; AnomalyResult detectResult; @Override @@ -82,9 +83,9 @@ public void setUp() throws Exception { threadPool = mock(ThreadPool.class); setUpADThreadPool(threadPool); - resultHandler = mock(MultiEntityResultHandler.class); + resultHandler = mock(ADIndexMemoryPressureAwareResultHandler.class); - resultWriteQueue = new ResultWriteWorker( + resultWriteQueue = new ADResultWriteWorker( Integer.MAX_VALUE, TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, @@ -111,10 +112,10 @@ public void setUp() throws Exception { public void testRegular() { List retryRequests = new ArrayList<>(); - ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests); + ResultBulkResponse resp = new ResultBulkResponse(retryRequests); ADResultBulkRequest request = new ADResultBulkRequest(); - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -124,12 +125,12 @@ public void testRegular() { request.add(resultWriteRequest); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onResponse(resp); return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // the request results one flush verify(resultHandler, times(1)).flush(any(), any()); @@ -143,10 +144,10 @@ public void testSingleRetryRequest() throws IOException { retryRequests.add(indexRequest); } - ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests); + ResultBulkResponse resp = new ResultBulkResponse(retryRequests); ADResultBulkRequest request = new ADResultBulkRequest(); - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -157,9 +158,9 @@ public void testSingleRetryRequest() throws IOException { final AtomicBoolean retried = new AtomicBoolean(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); if (retried.get()) { - listener.onResponse(new ADResultBulkResponse()); + listener.onResponse(new 
ResultBulkResponse()); } else { retried.set(true); listener.onResponse(resp); @@ -167,7 +168,7 @@ public void testSingleRetryRequest() throws IOException { return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // one flush from the original request; and one due to retry verify(resultHandler, times(2)).flush(any(), any()); @@ -176,9 +177,9 @@ public void testSingleRetryRequest() throws IOException { public void testRetryException() { final AtomicBoolean retried = new AtomicBoolean(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); if (retried.get()) { - listener.onResponse(new ADResultBulkResponse()); + listener.onResponse(new ResultBulkResponse()); } else { retried.set(true); listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); @@ -187,7 +188,7 @@ public void testRetryException() { return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // one flush from the original request; and one due to retry verify(resultHandler, times(2)).flush(any(), any()); verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchStatusException.class)); @@ -195,13 +196,13 @@ public void testRetryException() { public void testOverloaded() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); return null; }).when(resultHandler).flush(any(), any()); - resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); + resultWriteQueue.put(new ADResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null)); // one flush from the original request; and one due to retry verify(resultHandler, times(1)).flush(any(), any()); verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchRejectedExecutionException.class)); diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java index 3411f37ac..b207faf1b 100644 --- a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java +++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java @@ -43,10 +43,13 @@ import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TimeSeriesTask; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -211,7 +214,8 @@ public static Response createAnomalyDetector( categoryFields, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + 
TestHelpers.randomImputationOption(), + 0.01d ); if (historical) { @@ -260,12 +264,12 @@ public static List searchLatestAdTaskOfDetector(RestClient client, Strin for (Object adTaskResponse : adTaskResponses) { String id = (String) ((Map) adTaskResponse).get("_id"); Map source = (Map) ((Map) adTaskResponse).get("_source"); - String state = (String) source.get(ADTask.STATE_FIELD); + String state = (String) source.get(TimeSeriesTask.STATE_FIELD); String parsedDetectorId = (String) source.get(ADTask.DETECTOR_ID_FIELD); - Double taskProgress = (Double) source.get(ADTask.TASK_PROGRESS_FIELD); - Double initProgress = (Double) source.get(ADTask.INIT_PROGRESS_FIELD); - String parsedTaskType = (String) source.get(ADTask.TASK_TYPE_FIELD); - String coordinatingNode = (String) source.get(ADTask.COORDINATING_NODE_FIELD); + Double taskProgress = (Double) source.get(TimeSeriesTask.TASK_PROGRESS_FIELD); + Double initProgress = (Double) source.get(TimeSeriesTask.INIT_PROGRESS_FIELD); + String parsedTaskType = (String) source.get(TimeSeriesTask.TASK_TYPE_FIELD); + String coordinatingNode = (String) source.get(TimeSeriesTask.COORDINATING_NODE_FIELD); ADTask adTask = ADTask .builder() .taskId(id) @@ -366,7 +370,8 @@ public static Map getDetectorWithJobAndTask(RestClient client, S Instant.ofEpochMilli(lastUpdateTime), null, null, - null + null, + AnalysisType.AD ); results.put(ANOMALY_DETECTOR_JOB, job); } @@ -387,13 +392,13 @@ public static Map getDetectorWithJobAndTask(RestClient client, S } private static ADTask parseAdTask(Map taskMap) { - String id = (String) taskMap.get(ADTask.TASK_ID_FIELD); - String state = (String) taskMap.get(ADTask.STATE_FIELD); + String id = (String) taskMap.get(TimeSeriesTask.TASK_ID_FIELD); + String state = (String) taskMap.get(TimeSeriesTask.STATE_FIELD); String parsedDetectorId = (String) taskMap.get(ADTask.DETECTOR_ID_FIELD); - Double taskProgress = (Double) taskMap.get(ADTask.TASK_PROGRESS_FIELD); - Double initProgress = (Double) taskMap.get(ADTask.INIT_PROGRESS_FIELD); - String parsedTaskType = (String) taskMap.get(ADTask.TASK_TYPE_FIELD); - String coordinatingNode = (String) taskMap.get(ADTask.COORDINATING_NODE_FIELD); + Double taskProgress = (Double) taskMap.get(TimeSeriesTask.TASK_PROGRESS_FIELD); + Double initProgress = (Double) taskMap.get(TimeSeriesTask.INIT_PROGRESS_FIELD); + String parsedTaskType = (String) taskMap.get(TimeSeriesTask.TASK_TYPE_FIELD); + String coordinatingNode = (String) taskMap.get(TimeSeriesTask.COORDINATING_NODE_FIELD); return ADTask .builder() .taskId(id) @@ -465,16 +470,16 @@ public static String startHistoricalAnalysis(RestClient client, String detectorI return taskId; } - public static ADTaskProfile waitUntilTaskDone(RestClient client, String detectorId) throws InterruptedException { + public static TaskProfile waitUntilTaskDone(RestClient client, String detectorId) throws InterruptedException { return waitUntilTaskReachState(client, detectorId, TestHelpers.HISTORICAL_ANALYSIS_DONE_STATS); } - public static ADTaskProfile waitUntilTaskReachState(RestClient client, String detectorId, Set targetStates) + public static TaskProfile waitUntilTaskReachState(RestClient client, String detectorId, Set targetStates) throws InterruptedException { int i = 0; int retryTimes = 200; - ADTaskProfile adTaskProfile = null; - while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getAdTask().getState())) && i < retryTimes) { + TaskProfile adTaskProfile = null; + while ((adTaskProfile == null || 
!targetStates.contains(adTaskProfile.getTask().getState())) && i < retryTimes) { try { adTaskProfile = getADTaskProfile(client, detectorId); } catch (Exception e) { @@ -488,7 +493,7 @@ public static ADTaskProfile waitUntilTaskReachState(RestClient client, String de return adTaskProfile; } - public static ADTaskProfile getADTaskProfile(RestClient client, String detectorId) throws IOException, ParseException { + public static TaskProfile getADTaskProfile(RestClient client, String detectorId) throws IOException, ParseException { Response profileResponse = TestHelpers .makeRequest( client, @@ -501,10 +506,10 @@ public static ADTaskProfile getADTaskProfile(RestClient client, String detectorI return parseADTaskProfile(profileResponse); } - public static ADTaskProfile parseADTaskProfile(Response profileResponse) throws IOException, ParseException { + public static TaskProfile parseADTaskProfile(Response profileResponse) throws IOException, ParseException { String profileResult = EntityUtils.toString(profileResponse.getEntity()); XContentParser parser = TestHelpers.parser(profileResult); - ADTaskProfile adTaskProfile = null; + TaskProfile adTaskProfile = null; while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index 084e2d44f..4c7ee5797 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -14,7 +14,6 @@ import static org.hamcrest.Matchers.containsString; import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.DUPLICATE_DETECTOR_MSG; import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import java.io.IOException; import java.time.Instant; @@ -37,7 +36,6 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorExecutionInput; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; @@ -54,6 +52,7 @@ import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; @@ -145,7 +144,8 @@ public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); TestHelpers @@ -218,7 +218,8 @@ public void testUpdateAnomalyDetectorCategoryField() throws Exception { ImmutableList.of(randomAlphaOfLength(5)), detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); Exception ex = expectThrows( ResponseException.class, @@ -275,7 +276,8 @@ public void testUpdateAnomalyDetector() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); 
updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); @@ -337,7 +339,8 @@ public void testUpdateAnomalyDetectorNameToExisting() throws Exception { null, detector1.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); TestHelpers @@ -376,7 +379,8 @@ public void testUpdateAnomalyDetectorNameToNew() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); TestHelpers @@ -422,7 +426,8 @@ public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); deleteIndexWithAdminClient(CommonName.CONFIG_INDEX); @@ -785,7 +790,8 @@ public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { null, detector.getUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + 0.01d ); TestHelpers @@ -895,7 +901,7 @@ public void testStartAdJobWithNonexistingDetector() throws Exception { TestHelpers .assertFailWith( ResponseException.class, - FAIL_TO_FIND_CONFIG_MSG, + CommonMessages.FAIL_TO_FIND_CONFIG_MSG, () -> TestHelpers .makeRequest( client(), @@ -997,7 +1003,7 @@ public void testStopNonExistingAdJob() throws Exception { TestHelpers .assertFailWith( ResponseException.class, - FAIL_TO_FIND_CONFIG_MSG, + CommonMessages.FAIL_TO_FIND_CONFIG_MSG, () -> TestHelpers .makeRequest( client(), @@ -1332,7 +1338,7 @@ public void testValidateAnomalyDetectorOnWrongValidationType() throws Exception TestHelpers .assertFailWith( ResponseException.class, - ADCommonMessages.NOT_EXISTENT_VALIDATION_TYPE, + CommonMessages.NOT_EXISTENT_VALIDATION_TYPE, () -> TestHelpers .makeRequest( client(), @@ -1475,7 +1481,7 @@ public void testValidateAnomalyDetectorWithWrongCategoryField() throws Exception .extractValue("detector", responseMap); assertEquals( "non-existing category", - String.format(Locale.ROOT, AbstractAnomalyDetectorActionHandler.CATEGORY_NOT_FOUND_ERR_MSG, "host.keyword"), + String.format(Locale.ROOT, AbstractTimeSeriesActionHandler.CATEGORY_NOT_FOUND_ERR_MSG, "host.keyword"), messageMap.get("category_field").get("message") ); diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index e3881c968..2f9aa8751 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -17,8 +17,8 @@ import static org.opensearch.timeseries.TestHelpers.AD_BASE_STATS_URI; import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS; import static org.opensearch.timeseries.stats.StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT; -import static org.opensearch.timeseries.stats.StatNames.MULTI_ENTITY_DETECTOR_COUNT; -import static org.opensearch.timeseries.stats.StatNames.SINGLE_ENTITY_DETECTOR_COUNT; +import static org.opensearch.timeseries.stats.StatNames.HC_DETECTOR_COUNT; +import static org.opensearch.timeseries.stats.StatNames.SINGLE_STREAM_DETECTOR_COUNT; import java.io.IOException; import java.util.List; @@ -39,6 +39,7 @@ import org.opensearch.client.ResponseException; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import 
org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; @@ -92,9 +93,9 @@ public void testHistoricalAnalysisForMultiCategoryHC() throws Exception { private void checkIfTaskCanFinishCorrectly(String detectorId, String taskId, Set states) throws InterruptedException { List results = waitUntilTaskDone(detectorId); - ADTaskProfile endTaskProfile = (ADTaskProfile) results.get(0); + TaskProfile endTaskProfile = (TaskProfile) results.get(0); Integer retryCount = (Integer) results.get(1); - ADTask stoppedAdTask = endTaskProfile.getAdTask(); + ADTask stoppedAdTask = endTaskProfile.getTask(); assertEquals(taskId, stoppedAdTask.getTaskId()); if (retryCount < MAX_RETRY_TIMES) { // It's possible that historical analysis still running after max retry times @@ -118,14 +119,14 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul // get task profile ADTaskProfile adTaskProfile = waitUntilGetTaskProfile(detectorId); if (categoryFieldSize > 0) { - if (!TaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) { + if (!TaskState.RUNNING.name().equals(adTaskProfile.getTask().getState())) { adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0); } assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue()); assertTrue(adTaskProfile.getPendingEntitiesCount() > 0); assertTrue(adTaskProfile.getRunningEntitiesCount() > 0); } - ADTask adTask = adTaskProfile.getAdTask(); + ADTask adTask = adTaskProfile.getTask(); assertEquals(taskId, adTask.getTaskId()); assertTrue(TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(adTask.getState())); @@ -133,7 +134,7 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul Response statsResponse = TestHelpers.makeRequest(client(), "GET", AD_BASE_STATS_URI, ImmutableMap.of(), "", null); String statsResult = EntityUtils.toString(statsResponse.getEntity()); Map stringObjectMap = TestHelpers.parseStatsResult(statsResult); - String detectorCountState = categoryFieldSize > 0 ? MULTI_ENTITY_DETECTOR_COUNT.getName() : SINGLE_ENTITY_DETECTOR_COUNT.getName(); + String detectorCountState = categoryFieldSize > 0 ? HC_DETECTOR_COUNT.getName() : SINGLE_STREAM_DETECTOR_COUNT.getName(); assertTrue((long) stringObjectMap.get(detectorCountState) > 0); Map nodes = (Map) stringObjectMap.get("nodes"); long totalBatchTaskExecution = 0; @@ -317,7 +318,8 @@ private AnomalyDetector randomAnomalyDetector(AnomalyDetector detector) { detector.getCategoryFields(), detector.getUser(), detector.getCustomResultIndex(), - detector.getImputationOption() + detector.getImputationOption(), + 0.01d ); } diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java index dcadb41ac..e9c176de6 100644 --- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java @@ -262,7 +262,8 @@ public void testUpdateApiFilterByEnabledForAdmin() throws IOException { ImmutableList.of(randomAlphaOfLength(5)) ), null, - aliceDetector.getImputationOption() + aliceDetector.getImputationOption(), + 0.01d ); // User client has admin all access, and has "opensearch" backend role so client should be able to update detector // But the detector's backend role should not be replaced as client's backend roles (all_access). 
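The filter-by assertions in SecureADRestIT around this hunk come down to backend-role overlap: with filtering enabled, a user may operate on a detector only when at least one of their backend roles matches a role stored on it, while admin updates must not overwrite the stored roles. A hedged sketch of that rule as these tests exercise it (plain Java, not the plugin's actual security plumbing):

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Illustrative sketch of the backend-role overlap check the filter-by tests
// assert; the real enforcement lives in the security integration, not here.
public final class BackendRoleFilterSketch {
    static boolean canAccess(List<String> userBackendRoles, List<String> detectorBackendRoles) {
        Set<String> overlap = new HashSet<>(userBackendRoles);
        overlap.retainAll(detectorBackendRoles);
        return !overlap.isEmpty();
    }

    public static void main(String[] args) {
        // Fish shares the "odfe" role with Alice's detector, so access is allowed.
        System.out.println(canAccess(List.of("odfe", "aes"), List.of("odfe")));        // true
        // Disjoint roles are rejected even for a user with AD full access.
        System.out.println(canAccess(List.of("all_access"), Collections.emptyList())); // false
    }
}
```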
@@ -309,7 +310,8 @@ public void testUpdateApiFilterByEnabled() throws IOException { ImmutableList.of(randomAlphaOfLength(5)) ), null, - aliceDetector.getImputationOption() + aliceDetector.getImputationOption(), + 0.01d ); enableFilterBy(); // User Fish has AD full access, and has "odfe" backend role which is one of Alice's backend role, so diff --git a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java index 59853ce62..ae18aa6c2 100644 --- a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java +++ b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java @@ -21,7 +21,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; import java.io.IOException; import java.util.Arrays; @@ -36,20 +35,18 @@ import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.ADProfileAction; import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.ProfileAction; -import org.opensearch.ad.transport.ProfileResponse; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; import org.opensearch.client.Client; -import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.search.aggregations.AggregationBuilder; @@ -59,7 +56,11 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; @@ -69,12 +70,8 @@ public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCas private static ADIndexManagement anomalyDetectionIndices; private static String detectorId; - private static Long seqNo; - private static Long primaryTerm; private static NamedXContentRegistry xContentRegistry; - private static TransportService transportService; - private static TimeValue requestTimeout; private static DiscoveryNodeFilterer nodeFilter; private static AnomalyDetector detector; @@ -84,22 +81,20 @@ public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCas private ExecuteADResultResponseRecorder 
recorder; private Client client; - private IndexAnomalyDetectorJobActionHandler handler; - private AnomalyIndexHandler anomalyResultHandler; + private ADIndexJobActionHandler handler; + private ResultBulkIndexingHandler anomalyResultHandler; private NodeStateManager nodeStateManager; private ADTaskCacheManager adTaskCacheManager; + private TransportService transportService; @BeforeClass public static void setOnce() throws IOException { detectorId = "123"; - seqNo = 1L; - primaryTerm = 2L; anomalyDetectionIndices = mock(ADIndexManagement.class); xContentRegistry = NamedXContentRegistry.EMPTY; - transportService = mock(TransportService.class); - - requestTimeout = TimeValue.timeValueMinutes(60); when(anomalyDetectionIndices.doesJobIndexExist()).thenReturn(true); + // make sure getAndExecuteOnLatestConfigLevelTask called in startConfig + when(anomalyDetectionIndices.doesStateIndexExist()).thenReturn(true); nodeFilter = mock(DiscoveryNodeFilterer.class); detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); @@ -137,7 +132,7 @@ public void setUp() throws Exception { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[2]; - AnomalyResultResponse response = new AnomalyResultResponse(null, "", 0L, 10L, true); + AnomalyResultResponse response = new AnomalyResultResponse(null, "", 0L, 10L, true, null); listener.onResponse(response); return null; @@ -146,17 +141,17 @@ public void setUp() throws Exception { adTaskManager = mock(ADTaskManager.class); doAnswer(invocation -> { Object[] args = invocation.getArguments(); - ActionListener listener = (ActionListener) args[4]; + ActionListener listener = (ActionListener) args[4]; - AnomalyDetectorJobResponse response = mock(AnomalyDetectorJobResponse.class); + JobResponse response = mock(JobResponse.class); listener.onResponse(response); return null; - }).when(adTaskManager).startDetector(any(), any(), any(), any(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); threadPool = mock(ThreadPool.class); - anomalyResultHandler = mock(AnomalyIndexHandler.class); + anomalyResultHandler = mock(ResultBulkIndexingHandler.class); nodeStateManager = mock(NodeStateManager.class); @@ -175,31 +170,30 @@ public void setUp() throws Exception { 32 ); - handler = new IndexAnomalyDetectorJobActionHandler( + handler = new ADIndexJobActionHandler( client, anomalyDetectionIndices, - detectorId, - seqNo, - primaryTerm, - requestTimeout, xContentRegistry, - transportService, adTaskManager, - recorder + recorder, + nodeStateManager, + Settings.EMPTY ); + + transportService = mock(TransportService.class); } @SuppressWarnings("unchecked") public void testDelayHCProfile() { when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(false); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(1)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(threadPool, times(1)).schedule(any(), any(), any()); verify(listener, 
times(1)).onResponse(any()); @@ -216,17 +210,17 @@ public void testNoDelayHCProfile() { listener.onResponse(response); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(threadPool, never()).schedule(any(), any(), any()); @@ -242,17 +236,17 @@ public void testHCProfileException() { listener.onFailure(new RuntimeException()); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, never()).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(threadPool, never()).schedule(any(), any(), any()); @@ -270,7 +264,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundExcept listener.onResponse(response); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); @@ -278,18 +272,18 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundExcept Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[5]; - listener.onFailure(new ResourceNotFoundException(CAN_NOT_FIND_LATEST_TASK)); + listener.onFailure(new ResourceNotFoundException(CommonMessages.CAN_NOT_FIND_LATEST_TASK)); return null; }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, 
times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(adTaskManager, times(1)).removeRealtimeTaskCache(anyString()); @@ -308,7 +302,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() { listener.onResponse(response); return null; - }).when(client).execute(any(ProfileAction.class), any(), any()); + }).when(client).execute(any(ADProfileAction.class), any(), any()); when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); @@ -321,17 +315,17 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() { return null; }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(client, times(1)).get(any(), any()); verify(client, times(2)).execute(any(), any(), any()); - verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).getAndExecuteOnLatestConfigLevelTask(any(), any(), eq(false), any(), any(), any()); verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); verify(adTaskManager, never()).removeRealtimeTaskCache(anyString()); - verify(adTaskManager, times(1)).skipUpdateHCRealtimeTask(anyString(), anyString()); + verify(adTaskManager, times(1)).skipUpdateRealtimeTask(anyString(), anyString()); verify(threadPool, never()).schedule(any(), any(), any()); verify(listener, times(1)).onResponse(any()); } @@ -347,7 +341,7 @@ public void testIndexException() throws IOException { return null; }).when(client).execute(any(AnomalyResultAction.class), any(), any()); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); AggregationBuilder aggregationBuilder = TestHelpers .parseAggregation("{\"test\":{\"max\":{\"field\":\"" + MockSimpleLog.VALUE_FIELD + "\"}}}"); Feature feature = new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); @@ -361,7 +355,7 @@ public void testIndexException() throws IOException { ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "index" ); when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); - handler.startAnomalyDetectorJob(detector, listener); + handler.startJob(detector, transportService, listener); verify(anomalyResultHandler, times(1)).index(any(), any(), eq(null)); verify(threadPool, times(1)).schedule(any(), any(), any()); } diff --git a/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java b/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java index 6de90a068..5e574e77d 100644 --- a/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java +++ b/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java @@ -19,6 +19,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.settings.TimeSeriesEnabledSetting; public class 
ADEnabledSettingTests extends OpenSearchTestCase {
@@ -30,9 +31,9 @@ public void testIsADEnabled() {
     }
 
     public void testIsADBreakerEnabled() {
-        assertTrue(ADEnabledSetting.isADBreakerEnabled());
+        assertTrue(TimeSeriesEnabledSetting.isBreakerEnabled());
         ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.AD_BREAKER_ENABLED, false);
-        assertTrue(!ADEnabledSetting.isADBreakerEnabled());
+        assertTrue(!TimeSeriesEnabledSetting.isBreakerEnabled());
     }
 
     public void testIsInterpolationInColdStartEnabled() {
diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java
index 085ea5959..46cd0e619 100644
--- a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java
+++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java
@@ -155,7 +155,7 @@ public void testAllLegacyOpenDistroSettingsFallback() {
             LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY)
         );
         assertEquals(
-            TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY)
         );
         assertEquals(
@@ -163,7 +163,7 @@ public void testAllLegacyOpenDistroSettingsFallback() {
             LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES.get(Settings.EMPTY)
         );
         assertEquals(
-            TimeSeriesSettings.BACKOFF_MINUTES.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_BACKOFF_MINUTES.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES.get(Settings.EMPTY)
         );
         assertEquals(
diff --git a/src/test/java/org/opensearch/ad/stats/ADStatTests.java b/src/test/java/org/opensearch/ad/stats/ADStatTests.java
index 1912f92ad..7ec161f1b 100644
--- a/src/test/java/org/opensearch/ad/stats/ADStatTests.java
+++ b/src/test/java/org/opensearch/ad/stats/ADStatTests.java
@@ -14,32 +14,33 @@
 import java.util.function.Supplier;
 
 import org.junit.Test;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
-import org.opensearch.ad.stats.suppliers.SettableSupplier;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.stats.suppliers.SettableSupplier;
 
 public class ADStatTests extends OpenSearchTestCase {
 
     @Test
     public void testIsClusterLevel() {
-        ADStat<String> stat1 = new ADStat<>(true, new TestSupplier());
+        TimeSeriesStat<String> stat1 = new TimeSeriesStat<>(true, new TestSupplier());
         assertTrue("isCluster returns the wrong value", stat1.isClusterLevel());
-        ADStat<String> stat2 = new ADStat<>(false, new TestSupplier());
+        TimeSeriesStat<String> stat2 = new TimeSeriesStat<>(false, new TestSupplier());
         assertTrue("isCluster returns the wrong value", !stat2.isClusterLevel());
     }
 
     @Test
     public void testGetValue() {
-        ADStat<Long> stat1 = new ADStat<>(false, new CounterSupplier());
+        TimeSeriesStat<Long> stat1 = new TimeSeriesStat<>(false, new CounterSupplier());
         assertEquals("GetValue returns the incorrect value", 0L, (long) (stat1.getValue()));
-        ADStat<String> stat2 = new ADStat<>(false, new TestSupplier());
+        TimeSeriesStat<String> stat2 = new TimeSeriesStat<>(false, new TestSupplier());
         assertEquals("GetValue returns the incorrect value", "test", stat2.getValue());
     }
 
     @Test
     public void testSetValue() {
-        ADStat<Long> stat = new ADStat<>(false, new SettableSupplier());
+        TimeSeriesStat<Long> stat = new TimeSeriesStat<>(false, new SettableSupplier());
         assertEquals("GetValue returns the incorrect value", 0L, (long) (stat.getValue()));
         stat.setValue(10L);
         assertEquals("GetValue returns the incorrect value", 10L, (long) stat.getValue());
@@ -47,7 +48,7 @@ public void testSetValue() {
 
     @Test
     public void testIncrement() {
-        ADStat<Long> incrementStat = new ADStat<>(false, new CounterSupplier());
+        TimeSeriesStat<Long> incrementStat = new TimeSeriesStat<>(false, new CounterSupplier());
 
         for (Long i = 0L; i < 100; i++) {
             assertEquals("increment does not work", i, incrementStat.getValue());
@@ -55,7 +56,7 @@ public void testIncrement() {
         }
 
         // Ensure that no problems occur for a stat that cannot be incremented
-        ADStat<String> nonIncStat = new ADStat<>(false, new TestSupplier());
+        TimeSeriesStat<String> nonIncStat = new TimeSeriesStat<>(false, new TestSupplier());
         nonIncStat.increment();
     }
 
diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java
index 194623bd5..b47bd8a0d 100644
--- a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java
+++ b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java
@@ -19,18 +19,19 @@
 import org.junit.Test;
 import org.opensearch.action.FailedNodeException;
-import org.opensearch.ad.transport.ADStatsNodeResponse;
-import org.opensearch.ad.transport.ADStatsNodesResponse;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.transport.StatsNodeResponse;
+import org.opensearch.timeseries.transport.StatsNodesResponse;
+import org.opensearch.timeseries.transport.StatsResponse;
 
 public class ADStatsResponseTests extends OpenSearchTestCase {
     @Test
     public void testGetAndSetClusterStats() {
-        ADStatsResponse adStatsResponse = new ADStatsResponse();
+        StatsResponse adStatsResponse = new StatsResponse();
         Map<String, Object> testClusterStats = new HashMap<>();
         testClusterStats.put("test_stat", 1L);
         adStatsResponse.setClusterStats(testClusterStats);
@@ -39,53 +40,53 @@ public void testGetAndSetClusterStats() {
 
     @Test
     public void testGetAndSetADStatsNodesResponse() {
-        ADStatsResponse adStatsResponse = new ADStatsResponse();
-        List<ADStatsNodeResponse> responses = Collections.emptyList();
+        StatsResponse adStatsResponse = new StatsResponse();
+        List<StatsNodeResponse> responses = Collections.emptyList();
         List<FailedNodeException> failures = Collections.emptyList();
-        ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures);
-        adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse);
-        assertEquals(adStatsNodesResponse, adStatsResponse.getADStatsNodesResponse());
+        StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures);
+        adStatsResponse.setStatsNodesResponse(adStatsNodesResponse);
+        assertEquals(adStatsNodesResponse, adStatsResponse.getStatsNodesResponse());
     }
 
     @Test
     public void testMerge() {
-        ADStatsResponse adStatsResponse1 = new ADStatsResponse();
+        StatsResponse adStatsResponse1 = new StatsResponse();
         Map<String, Object> testClusterStats = new HashMap<>();
         testClusterStats.put("test_stat", 1L);
         adStatsResponse1.setClusterStats(testClusterStats);
-        ADStatsResponse adStatsResponse2 = new ADStatsResponse();
-        List<ADStatsNodeResponse> responses = Collections.emptyList();
+        StatsResponse adStatsResponse2 = new StatsResponse();
+        List<StatsNodeResponse> responses = Collections.emptyList();
         List<FailedNodeException> failures = Collections.emptyList();
-        ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures);
-        adStatsResponse2.setADStatsNodesResponse(adStatsNodesResponse);
+        StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures);
+        adStatsResponse2.setStatsNodesResponse(adStatsNodesResponse);
         adStatsResponse1.merge(adStatsResponse2);
         assertEquals(testClusterStats, adStatsResponse1.getClusterStats());
-        assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse());
+        assertEquals(adStatsNodesResponse, adStatsResponse1.getStatsNodesResponse());
         adStatsResponse2.merge(adStatsResponse1);
         assertEquals(testClusterStats, adStatsResponse2.getClusterStats());
-        assertEquals(adStatsNodesResponse, adStatsResponse2.getADStatsNodesResponse());
+        assertEquals(adStatsNodesResponse, adStatsResponse2.getStatsNodesResponse());
 
         // Confirm merging with null does nothing
         adStatsResponse1.merge(null);
         assertEquals(testClusterStats, adStatsResponse1.getClusterStats());
-        assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse());
+        assertEquals(adStatsNodesResponse, adStatsResponse1.getStatsNodesResponse());
 
         // Confirm merging with self does nothing
         adStatsResponse1.merge(adStatsResponse1);
         assertEquals(testClusterStats, adStatsResponse1.getClusterStats());
-        assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse());
+        assertEquals(adStatsNodesResponse, adStatsResponse1.getStatsNodesResponse());
     }
 
     @Test
     public void testEquals() {
-        ADStatsResponse adStatsResponse1 = new ADStatsResponse();
+        StatsResponse adStatsResponse1 = new StatsResponse();
         assertEquals(adStatsResponse1, adStatsResponse1);
         assertNotEquals(null, adStatsResponse1);
         assertNotEquals(1, adStatsResponse1);
-        ADStatsResponse adStatsResponse2 = new ADStatsResponse();
+        StatsResponse adStatsResponse2 = new StatsResponse();
         assertEquals(adStatsResponse1, adStatsResponse2);
         Map<String, Object> testClusterStats = new HashMap<>();
         testClusterStats.put("test_stat", 1L);
@@ -95,8 +96,8 @@ public void testEquals() {
 
     @Test
     public void testHashCode() {
-        ADStatsResponse adStatsResponse1 = new ADStatsResponse();
-        ADStatsResponse adStatsResponse2 = new ADStatsResponse();
+        StatsResponse adStatsResponse1 = new StatsResponse();
+        StatsResponse adStatsResponse2 = new StatsResponse();
         assertEquals(adStatsResponse1.hashCode(), adStatsResponse2.hashCode());
         Map<String, Object> testClusterStats = new HashMap<>();
         testClusterStats.put("test_stat", 1L);
@@ -106,14 +107,14 @@ public void testHashCode() {
 
     @Test
     public void testToXContent() throws IOException {
-        ADStatsResponse adStatsResponse = new ADStatsResponse();
+        StatsResponse adStatsResponse = new StatsResponse();
         Map<String, Object> testClusterStats = new HashMap<>();
         testClusterStats.put("test_stat", 1);
         adStatsResponse.setClusterStats(testClusterStats);
-        List<ADStatsNodeResponse> responses = Collections.emptyList();
+        List<StatsNodeResponse> responses = Collections.emptyList();
         List<FailedNodeException> failures = Collections.emptyList();
-        ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures);
-        adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse);
+        StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures);
+        adStatsResponse.setStatsNodesResponse(adStatsNodesResponse);
 
         XContentBuilder builder = XContentFactory.jsonBuilder();
         adStatsResponse.toXContent(builder);
diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java 
b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java index 6db1ac5cc..70de74167 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java @@ -17,6 +17,7 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; import java.time.Clock; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -24,36 +25,39 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.util.IndexUtils; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class ADStatsTests extends OpenSearchTestCase { - private Map> statsMap; + private Map> statsMap; private ADStats adStats; private RandomCutForest rcf; private HybridThresholdingModel thresholdingModel; @@ -64,10 +68,10 @@ public class ADStatsTests extends OpenSearchTestCase { private Clock clock; @Mock - private ModelManager modelManager; + private ADModelManager modelManager; @Mock - private CacheProvider cacheProvider; + private ADCacheProvider cacheProvider; @Before public void setup() { @@ -80,20 +84,62 @@ public void setup() { List> modelsInformation = new ArrayList<>( Arrays .asList( - new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + new ModelState<>( + rcf, + "rcf-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-1", + 
"detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + rcf, + "rcf-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ) ) ); when(modelManager.getAllModels()).thenReturn(modelsInformation); - ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel1 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); - EntityCache cache = mock(EntityCache.class); + List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + ADPriorityCache cache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(cache); when(cache.getAllModels()).thenReturn(entityModelsInformation); @@ -115,12 +161,15 @@ public void setup() { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - statsMap = new HashMap>() { + statsMap = new HashMap>() { { - put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))); - put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); - put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); + put(nodeStatName1, new TimeSeriesStat<>(false, new CounterSupplier())); + put( + nodeStatName2, + new TimeSeriesStat<>(false, new ADModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) + ); + put(clusterStatName1, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); + put(clusterStatName2, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } }; @@ -134,11 +183,11 @@ public void testStatNamesGetNames() { @Test public void testGetStats() { - Map> stats = adStats.getStats(); + Map> stats = adStats.getStats(); assertEquals("getStats returns the incorrect number of stats", stats.size(), statsMap.size()); - for (Map.Entry> stat : stats.entrySet()) { + for (Map.Entry> stat : stats.entrySet()) { assertTrue( "getStats returns incorrect stats", adStats.getStats().containsKey(stat.getKey()) && adStats.getStats().get(stat.getKey()) == stat.getValue() @@ -148,7 +197,7 @@ public void testGetStats() { @Test public void testGetStat() { - ADStat stat = adStats.getStat(clusterStatName1); + TimeSeriesStat stat = adStats.getStat(clusterStatName1); assertTrue( "getStat returns incorrect stat", @@ -158,10 +207,10 @@ public void testGetStat() { @Test public void testGetNodeStats() { - Map> stats = adStats.getStats(); - Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); + Map> stats = adStats.getStats(); + Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); - for (ADStat stat : stats.values()) { + for (TimeSeriesStat stat : stats.values()) { assertTrue( "getNodeStats returns 
incorrect stat", (stat.isClusterLevel() && !nodeStats.contains(stat)) || (!stat.isClusterLevel() && nodeStats.contains(stat)) @@ -171,10 +220,10 @@ public void testGetNodeStats() { @Test public void testGetClusterStats() { - Map> stats = adStats.getStats(); - Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); + Map> stats = adStats.getStats(); + Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); - for (ADStat stat : stats.values()) { + for (TimeSeriesStat stat : stats.values()) { assertTrue( "getClusterStats returns incorrect stat", (stat.isClusterLevel() && clusterStats.contains(stat)) || (!stat.isClusterLevel() && !clusterStats.contains(stat)) diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java index 333d50ffe..3490e0318 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java @@ -13,6 +13,7 @@ import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; public class CounterSupplierTests extends OpenSearchTestCase { @Test diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java index cfdf71188..409437490 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java @@ -16,8 +16,9 @@ import org.junit.Before; import org.junit.Test; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.util.IndexUtils; public class IndexSupplierTests extends OpenSearchTestCase { private IndexUtils indexUtils; diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java index 21a9e4aff..9a57a58d4 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java @@ -14,15 +14,17 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; -import static org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; +import static org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; import java.time.Clock; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -30,18 +32,19 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; -import 
org.opensearch.ad.ml.ModelState; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -51,13 +54,13 @@ public class ModelsOnNodeSupplierTests extends OpenSearchTestCase { private HybridThresholdingModel thresholdingModel; private List> expectedResults; private Clock clock; - private List> entityModelsInformation; + private List> entityModelsInformation; @Mock - private ModelManager modelManager; + private ADModelManager modelManager; @Mock - private CacheProvider cacheProvider; + private ADCacheProvider cacheProvider; @Before public void setup() { @@ -70,20 +73,62 @@ public void setup() { expectedResults = new ArrayList<>( Arrays .asList( - new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + new ModelState<>( + rcf, + "rcf-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-1", + "detector-1", + ModelManager.ModelType.RCF.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + rcf, + "rcf-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ), + new ModelState<>( + thresholdingModel, + "thr-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f, + null, + Optional.empty(), + new ArrayDeque<>() + ) ) ); when(modelManager.getAllModels()).thenReturn(expectedResults); - ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel1 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); - EntityCache cache = mock(EntityCache.class); + ADPriorityCache cache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(cache); when(cache.getAllModels()).thenReturn(entityModelsInformation); } @@ -98,7 +143,7 @@ public void testGet() { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - ModelsOnNodeSupplier modelsOnNodeSupplier = new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService); + ADModelsOnNodeSupplier modelsOnNodeSupplier = new ADModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService); List> results = modelsOnNodeSupplier.get(); 
assertEquals( "get fails to return correct result", @@ -119,7 +164,7 @@ public void testGet() { @Test public void testGetModelCount() { - ModelsOnNodeCountSupplier modelsOnNodeSupplier = new ModelsOnNodeCountSupplier(modelManager, cacheProvider); + ADModelsOnNodeCountSupplier modelsOnNodeSupplier = new ADModelsOnNodeCountSupplier(modelManager, cacheProvider); assertEquals(6L, modelsOnNodeSupplier.get().longValue()); } } diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java index 1cf1c9306..821871984 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java @@ -13,6 +13,7 @@ import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; public class SettableSupplierTests extends OpenSearchTestCase { @Test diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index ad14b49c4..ed43baa73 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -246,7 +246,7 @@ public void testTopEntityInited() throws IOException { assertTrue(adTaskCacheManager.topEntityInited(detectorId)); } - public void testEntityCache() throws IOException { + public void testADPriorityCache() throws IOException { String detectorId = randomAlphaOfLength(10); assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); @@ -321,14 +321,14 @@ public void testRealtimeTaskCache() { adTaskCacheManager.updateRealtimeTaskCache(detectorId1, newState, newInitProgress, newError); assertFalse(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); - assertArrayEquals(new String[] { detectorId1 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); + assertArrayEquals(new String[] { detectorId1 }, adTaskCacheManager.getConfigIdsInRealtimeTaskCache()); String detectorId2 = randomAlphaOfLength(10); adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); - assertEquals(1, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + assertEquals(1, adTaskCacheManager.getConfigIdsInRealtimeTaskCache().length); adTaskCacheManager.initRealtimeTaskCache(detectorId2, 60_000); adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); - assertEquals(2, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + assertEquals(2, adTaskCacheManager.getConfigIdsInRealtimeTaskCache().length); newState = TaskState.RUNNING.name(); newInitProgress = 1.0f; @@ -340,10 +340,10 @@ public void testRealtimeTaskCache() { assertEquals(newError, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getError()); adTaskCacheManager.removeRealtimeTaskCache(detectorId1); - assertArrayEquals(new String[] { detectorId2 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); + assertArrayEquals(new String[] { detectorId2 }, adTaskCacheManager.getConfigIdsInRealtimeTaskCache()); adTaskCacheManager.clearRealtimeTaskCache(); - assertEquals(0, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + assertEquals(0, adTaskCacheManager.getConfigIdsInRealtimeTaskCache().length); } 
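The cache tests above and the ADTaskManagerTests that follow all drive asynchronous APIs through the same Mockito idiom: stub the mocked collaborator with doAnswer, pull the ActionListener out of the invocation arguments, and complete it inline, so every verify(...) can run synchronously on the test thread. A minimal, self-contained sketch of that pattern follows; the ListenerStubbing class and clientCompletingWith helper are illustrative names for this sketch only, not part of this change.

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;

// Hypothetical helper showing the stubbing style these tests rely on.
public class ListenerStubbing {

    // Builds a mock Client whose execute(action, request, listener) call
    // completes the listener immediately with the given response, so code
    // under test finishes inline instead of waiting on a real transport call.
    @SuppressWarnings("unchecked")
    public static <R> Client clientCompletingWith(R response) {
        Client client = mock(Client.class);
        doAnswer(invocation -> {
            // The listener is the third argument of client.execute(action, request, listener).
            ActionListener<R> listener = (ActionListener<R>) invocation.getArgument(2);
            listener.onResponse(response);
            return null;
        }).when(client).execute(any(), any(), any());
        return client;
    }
}

The doAnswer stubs in the tests below, for getAndExecuteOnLatestConfigLevelTask, updateLatestRealtimeTaskOnCoordinatingNode, and the rest, follow the same shape, differing only in which argument index holds the listener.
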
diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index 9d09f7fee..184964343 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -45,10 +45,10 @@ import static org.opensearch.timeseries.TestHelpers.randomIntervalSchedule; import static org.opensearch.timeseries.TestHelpers.randomIntervalTimeConfiguration; import static org.opensearch.timeseries.TestHelpers.randomUser; -import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; import static org.opensearch.timeseries.model.Entity.createSingleAttributeEntity; import java.io.IOException; +import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; @@ -79,8 +79,9 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.ADTaskProfileRunner; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.ExecuteADResultResponseRecorder; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; @@ -88,13 +89,9 @@ import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; -import org.opensearch.ad.stats.InternalStatNames; -import org.opensearch.ad.transport.ADStatsNodeResponse; -import org.opensearch.ad.transport.ADStatsNodesResponse; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.transport.ADTaskProfileNodeResponse; import org.opensearch.ad.transport.ADTaskProfileResponse; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.ad.transport.ForwardADTaskRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -110,6 +107,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; @@ -121,15 +119,26 @@ import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import 
org.opensearch.timeseries.stats.InternalStatNames; import org.opensearch.timeseries.task.RealtimeTaskCache; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -154,10 +163,10 @@ public class ADTaskManagerTests extends ADUnitTestCase { private TransportService transportService; private ADTaskManager adTaskManager; private ThreadPool threadPool; - private IndexAnomalyDetectorJobActionHandler indexAnomalyDetectorJobActionHandler; + private ADIndexJobActionHandler indexAnomalyDetectorJobActionHandler; private DateRange detectionDateRange; - private ActionListener listener; + private ActionListener listener; private DiscoveryNode node1; private DiscoveryNode node2; @@ -200,7 +209,10 @@ public class ADTaskManagerTests extends ADUnitTestCase { + ",\"parent_task_id\":\"a1civ3sBwF58XZxvKrko\",\"worker_node\":\"DL5uOJV3TjOOAyh5hJXrCA\",\"current_piece\"" + ":1630999260000,\"execution_end_time\":1630999442814}}"; @Captor - ArgumentCaptor> remoteResponseHandler; + ArgumentCaptor> remoteResponseHandler; + + NodeStateManager nodeStateManager; + ADTaskProfileRunner taskProfileRunner; @Override public void setUp() throws Exception { @@ -240,7 +252,8 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(threadPool.getThreadContext()).thenReturn(threadContext); when(client.threadPool()).thenReturn(threadPool); - indexAnomalyDetectorJobActionHandler = mock(IndexAnomalyDetectorJobActionHandler.class); + nodeStateManager = mock(NodeStateManager.class); + taskProfileRunner = mock(ADTaskProfileRunner.class); adTaskManager = spy( new ADTaskManager( settings, @@ -251,13 +264,24 @@ public void setUp() throws Exception { nodeFilter, hashRing, adTaskCacheManager, - threadPool + threadPool, + nodeStateManager, + taskProfileRunner ) ); + indexAnomalyDetectorJobActionHandler = new ADIndexJobActionHandler( + client, + detectionIndices, + mock(NamedXContentRegistry.class), + adTaskManager, + mock(ExecuteADResultResponseRecorder.class), + nodeStateManager, + Settings.EMPTY + ); - listener = spy(new ActionListener() { + listener = spy(new ActionListener() { @Override - public void onResponse(AnomalyDetectorJobResponse bulkItemResponses) {} + public void onResponse(JobResponse bulkItemResponses) {} @Override public void onFailure(Exception e) {} @@ -313,7 +337,7 @@ private void setupHashRingWithSameLocalADVersionNodes() { Consumer function = invocation.getArgument(0); function.accept(new DiscoveryNode[] { node1, node2 }); return null; - }).when(hashRing).getNodesWithSameLocalAdVersion(any(), any()); + }).when(hashRing).getNodesWithSameLocalVersion(any(), any()); } private void setupHashRingWithOwningNode() { @@ -321,7 +345,7 @@ private void setupHashRingWithOwningNode() { Consumer> function = invocation.getArgument(1); function.accept(Optional.of(node1)); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(any(), any(), any()); } public void testCreateTaskIndexNotAcknowledged() throws IOException { @@ -334,9 +358,9 @@ public void testCreateTaskIndexNotAcknowledged() throws IOException { AnomalyDetector detector = 
randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); setupGetDetector(detector); - adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); assertEquals(error, exceptionCaptor.getValue().getMessage()); } @@ -350,7 +374,7 @@ public void testCreateTaskIndexWithResourceAlreadyExistsException() throws IOExc AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); setupGetDetector(detector); - adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, randomUser(), transportService, listener); verify(listener, never()).onFailure(any()); } @@ -365,12 +389,12 @@ public void testCreateTaskIndexWithException() throws IOException { AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); setupGetDetector(detector); - adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals(error, exceptionCaptor.getValue().getMessage()); } - public void testStartDetectorWithNoEnabledFeature() throws IOException { + public void testgetAndExecuteOnLatestConfigLevelTaskWithNoEnabledFeature() throws IOException { AnomalyDetector detector = randomDetector( ImmutableList.of(randomFeature(false)), randomAlphaOfLength(5), @@ -379,16 +403,7 @@ public void testStartDetectorWithNoEnabledFeature() throws IOException { ); setupGetDetector(detector); - adTaskManager - .startDetector( - detector.getId(), - detectionDateRange, - indexAnomalyDetectorJobActionHandler, - randomUser(), - transportService, - context, - listener - ); + adTaskManager.startHistorical(detector, detectionDateRange, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); } @@ -398,29 +413,20 @@ public void testStartDetectorForHistoricalAnalysis() throws IOException { setupGetDetector(detector); setupHashRingWithOwningNode(); - adTaskManager - .startDetector( - detector.getId(), - detectionDateRange, - indexAnomalyDetectorJobActionHandler, - randomUser(), - transportService, - context, - listener - ); + adTaskManager.startHistorical(detector, detectionDateRange, randomUser(), transportService, listener); verify(adTaskManager, times(1)).forwardRequestToLeadNode(any(), any(), any()); } private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, int node2UsedTaskSlots, int node2AssignedTaskSLots) { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener .onResponse( - new ADStatsNodesResponse( + new StatsNodesResponse( new ClusterName(randomAlphaOfLength(5)), ImmutableList .of( - 
new ADStatsNodeResponse( + new StatsNodeResponse( node1, ImmutableMap .of( @@ -430,7 +436,7 @@ private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, node1AssignedTaskSLots ) ), - new ADStatsNodeResponse( + new StatsNodeResponse( node2, ImmutableMap .of( @@ -556,7 +562,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForScale() throws IOExceptio public void testDeleteDuplicateTasks() throws IOException { ADTask adTask = randomAdTask(); - adTaskManager.handleADTaskException(adTask, new DuplicateTaskException("test")); + adTaskManager.handleTaskException(adTask, new DuplicateTaskException("test")); verify(client, times(1)).delete(any(), any()); } @@ -595,7 +601,7 @@ public void testDetectorTaskSlotScaleUpDelta() { DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; // Scale down - when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(hashRing.getNodesWithSameLocalVersion()).thenReturn(eligibleDataNodes); when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); int taskSlots = maxRunningEntities - 1; when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); @@ -608,7 +614,7 @@ public void testDetectorTaskSlotScaleDownDelta() { DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; // Scale down - when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(hashRing.getNodesWithSameLocalVersion()).thenReturn(eligibleDataNodes); when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); int taskSlots = maxRunningEntities * 5; when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); @@ -727,7 +733,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { ActionListener listener = invocation.getArgument(3); listener.onResponse(new UpdateResponse(ShardId.fromString("[test][1]"), "1", 0L, 1L, 1L, DocWriteResponse.Result.UPDATED)); return null; - }).when(adTaskManager).updateLatestADTask(anyString(), any(), anyMap(), any()); + }).when(adTaskManager).updateLatestTask(anyString(), any(), anyMap(), any()); adTaskManager .updateLatestRealtimeTaskOnCoordinatingNode( detectorId, @@ -775,7 +781,7 @@ public void testGetLocalADTaskProfilesByDetectorId() { @SuppressWarnings("unchecked") public void testRemoveStaleRunningEntity() throws IOException { - ActionListener actionListener = mock(ActionListener.class); + ActionListener actionListener = mock(ActionListener.class); ADTask adTask = randomAdTask(); String entity = randomAlphaOfLength(5); ExecutorService executeService = mock(ExecutorService.class); @@ -825,7 +831,7 @@ public void testResetLatestFlagAsFalse() throws IOException { public void testCleanADResultOfDeletedDetectorWithNoDeletedDetector() { when(adTaskCacheManager.pollDeletedConfig()).thenReturn(null); - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); verify(client, never()).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); } @@ -874,57 +880,59 @@ public void testCleanADResultOfDeletedDetectorWithException() { nodeFilter, hashRing, adTaskCacheManager, - threadPool + threadPool, + nodeStateManager, + taskProfileRunner ) ); - adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); verify(client, times(1)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).addDeletedConfig(eq(detectorId)); - 
adTaskManager.cleanADResultOfDeletedDetector(); + adTaskManager.cleanResultOfDeletedConfig(); verify(client, times(2)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).addDeletedConfig(eq(detectorId)); } public void testMaintainRunningHistoricalTasksWithOwningNodeIsNotLocalNode() { // Test no owning node - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.empty()); adTaskManager.maintainRunningHistoricalTasks(transportService, 10); verify(client, never()).search(any(), any()); // Test owning node is not local node - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node2)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node2)); doReturn(node1).when(clusterService).localNode(); adTaskManager.maintainRunningHistoricalTasks(transportService, 10); verify(client, never()).search(any(), any()); } public void testMaintainRunningHistoricalTasksWithNoRunningTask() { - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node1)); doReturn(node1).when(clusterService).localNode(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); InternalSearchResponse response = new InternalSearchResponse( - searchHits, - InternalAggregations.EMPTY, - null, - null, - false, - null, - 1 - ); + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); SearchResponse searchResponse = new SearchResponse( - response, - null, - 1, - 1, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); @@ -933,7 +941,7 @@ public void testMaintainRunningHistoricalTasksWithNoRunningTask() { } public void testMaintainRunningHistoricalTasksWithRunningTask() { - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node1)); doReturn(node1).when(clusterService).localNode(); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); @@ -946,24 +954,24 @@ public void testMaintainRunningHistoricalTasksWithRunningTask() { SearchHit task = SearchHit.fromXContent(TestHelpers.parser(runningHistoricalHCTaskContent)); SearchHits searchHits = new SearchHits(new SearchHit[] { task }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); InternalSearchResponse response = new InternalSearchResponse( - searchHits, - InternalAggregations.EMPTY, - null, - null, - false, - null, - 1 - ); + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); SearchResponse searchResponse = new SearchResponse( - response, - null, - 1, - 1, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); @@ -972,11 +980,11 @@ public void 
testMaintainRunningHistoricalTasksWithRunningTask() { } public void testMaintainRunningRealtimeTasksWithNoRealtimeTask() { - when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(null); + when(adTaskCacheManager.getConfigIdsInRealtimeTaskCache()).thenReturn(null); adTaskManager.maintainRunningRealtimeTasks(); verify(adTaskCacheManager, never()).removeRealtimeTaskCache(anyString()); - when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[0]); + when(adTaskCacheManager.getConfigIdsInRealtimeTaskCache()).thenReturn(new String[0]); adTaskManager.maintainRunningRealtimeTasks(); verify(adTaskCacheManager, never()).removeRealtimeTaskCache(anyString()); } @@ -985,7 +993,7 @@ public void testMaintainRunningRealtimeTasks() { String detectorId1 = randomAlphaOfLength(5); String detectorId2 = randomAlphaOfLength(5); String detectorId3 = randomAlphaOfLength(5); - when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[] { detectorId1, detectorId2, detectorId3 }); + when(adTaskCacheManager.getConfigIdsInRealtimeTaskCache()).thenReturn(new String[] { detectorId1, detectorId2, detectorId3 }); when(adTaskCacheManager.getRealtimeTaskCache(detectorId1)).thenReturn(null); RealtimeTaskCache cacheOfDetector2 = mock(RealtimeTaskCache.class); @@ -1006,12 +1014,12 @@ public void testStartHistoricalAnalysisWithNoOwningNode() throws IOException { DateRange detectionDateRange = TestHelpers.randomDetectionDateRange(); User user = null; int availableTaskSlots = randomIntBetween(1, 10); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(anyString(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(anyString(), any(), any()); adTaskManager.startHistoricalAnalysis(detector, detectionDateRange, user, availableTaskSlots, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1068,7 +1076,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningRealtimeTaskWithTaskStopp ); setupGetAndExecuteOnLatestADTasks(profile); adTaskManager - .getAndExecuteOnLatestADTasks( + .getAndExecuteOnLatestTasks( detectorId, null, null, @@ -1134,7 +1142,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningHistoricalTask() throws I ); setupGetAndExecuteOnLatestADTasks(profile); adTaskManager - .getAndExecuteOnLatestADTasks( + .getAndExecuteOnLatestTasks( detectorId, null, null, @@ -1187,13 +1195,13 @@ private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { }).when(client).search(any(), any()); String detectorId = randomAlphaOfLength(5); Consumer> function = mock(Consumer.class); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer getNodeFunction = invocation.getArgument(0); getNodeFunction.accept(new DiscoveryNode[] { node1, node2 }); return null; - }).when(hashRing).getAllEligibleDataNodesWithKnownAdVersion(any(), any()); + }).when(hashRing).getAllEligibleDataNodesWithKnownVersion(any(), any()); doAnswer(invocation -> { ActionListener taskProfileResponseListener = invocation.getArgument(2); @@ -1248,7 +1256,8 @@ private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { Instant.now(), 60L, TestHelpers.randomUser(), - null + null, + 
AnalysisType.AD ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) ), Collections.emptyMap(), @@ -1267,7 +1276,7 @@ public void testCreateADTaskDirectlyWithException() throws IOException { ActionListener listener = mock(ActionListener.class); doThrow(new RuntimeException("test")).when(client).index(any(), any()); - adTaskManager.createADTaskDirectly(adTask, function, listener); + adTaskManager.createTaskDirectly(adTask, function, listener); verify(listener, times(1)).onFailure(any()); doAnswer(invocation -> { @@ -1275,13 +1284,13 @@ public void testCreateADTaskDirectlyWithException() throws IOException { actionListener.onFailure(new RuntimeException("test")); return null; }).when(client).index(any(), any()); - adTaskManager.createADTaskDirectly(adTask, function, listener); + adTaskManager.createTaskDirectly(adTask, function, listener); verify(listener, times(2)).onFailure(any()); } public void testCleanChildTasksAndADResultsOfDeletedTaskWithNoDeletedDetectorTask() { when(adTaskCacheManager.hasDeletedTask()).thenReturn(false); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, never()).execute(any(), any(), any()); } @@ -1300,7 +1309,7 @@ public void testCleanChildTasksAndADResultsOfDeletedTaskWithNullTask() { return null; }).when(threadPool).schedule(any(), any(), any()); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, never()).execute(any(), any(), any()); } @@ -1319,7 +1328,7 @@ public void testCleanChildTasksAndADResultsOfDeletedTaskWithFailToDeleteADResult return null; }).when(threadPool).schedule(any(), any(), any()); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, times(1)).execute(any(), any(), any()); } @@ -1339,7 +1348,7 @@ public void testCleanChildTasksAndADResultsOfDeletedTask() { return null; }).when(threadPool).schedule(any(), any(), any()); - adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + adTaskManager.cleanChildTasksAndResultsOfDeletedTask(); verify(client, times(2)).execute(any(), any(), any()); } @@ -1356,7 +1365,7 @@ public void testDeleteADTasks() { String detectorId = randomAlphaOfLength(5); ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(function, times(1)).execute(); } @@ -1381,7 +1390,7 @@ public void testDeleteADTasksWithBulkFailures() { String detectorId = randomAlphaOfLength(5); ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(listener, times(1)).onFailure(any()); } @@ -1401,11 +1410,11 @@ public void testDeleteADTasksWithException() { ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(function, times(1)).execute(); verify(listener, never()).onFailure(any()); - adTaskManager.deleteADTasks(detectorId, function, listener); + adTaskManager.deleteTasks(detectorId, function, listener); verify(function, times(1)).execute(); 
verify(listener, times(1)).onFailure(any()); } @@ -1413,7 +1422,7 @@ public void testDeleteADTasksWithException() { @SuppressWarnings("unchecked") public void testScaleUpTaskSlots() throws IOException { ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); when(adTaskCacheManager.getAvailableNewEntityTaskLanes(anyString())).thenReturn(0); doReturn(2).when(adTaskManager).detectorTaskSlotScaleDelta(anyString()); when(adTaskCacheManager.getLastScaleEntityTaskLaneTime(anyString())).thenReturn(null); @@ -1433,12 +1442,12 @@ public void testScaleUpTaskSlots() throws IOException { public void testForwardRequestToLeadNodeWithNotExistingNode() throws IOException { ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest(adTask, ADTaskAction.APPLY_FOR_TASK_SLOTS); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(any(), any(), any()); adTaskManager.forwardRequestToLeadNode(forwardADTaskRequest, transportService, listener); verify(listener, times(1)).onFailure(any()); @@ -1449,19 +1458,19 @@ public void testScaleTaskLaneOnCoordinatingNode() { ADTask adTask = mock(ADTask.class); when(adTask.getCoordinatingNode()).thenReturn(node1.getId()); when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 }); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, transportService, listener); } @SuppressWarnings("unchecked") - public void testStartDetectorWithException() throws IOException { + public void testgetAndExecuteOnLatestConfigLevelTaskWithException() throws IOException { AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); DateRange detectionDateRange = randomDetectionDateRange(); User user = null; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); when(detectionIndices.doesStateIndexExist()).thenReturn(false); doThrow(new RuntimeException("test")).when(detectionIndices).initStateIndex(any()); - adTaskManager.startDetector(detector, detectionDateRange, user, transportService, listener); + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detector, detectionDateRange, false, user, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1469,13 +1478,13 @@ public void testStartDetectorWithException() throws IOException { public void testStopDetectorWithNonExistingDetector() { String detectorId = randomAlphaOfLength(5); boolean historical = true; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { - Consumer> function = invocation.getArgument(1); + Consumer> function = invocation.getArgument(2); function.accept(Optional.empty()); return null; - }).when(adTaskManager).getDetector(anyString(), any(), any()); - adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + 
}).when(nodeStateManager).getConfig(anyString(), eq(AnalysisType.AD), any(Consumer.class), any()); + indexAnomalyDetectorJobActionHandler.stopConfig(detectorId, historical, null, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1483,13 +1492,13 @@ public void testStopDetectorWithNonExistingDetector() { public void testStopDetectorWithNonExistingTask() { String detectorId = randomAlphaOfLength(5); boolean historical = true; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { - Consumer> function = invocation.getArgument(1); + Consumer> function = invocation.getArgument(2); AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); function.accept(Optional.of(detector)); return null; - }).when(adTaskManager).getDetector(anyString(), any(), any()); + }).when(nodeStateManager).getConfig(anyString(), eq(AnalysisType.AD), any(Consumer.class), any()); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -1497,7 +1506,7 @@ public void testStopDetectorWithNonExistingTask() { return null; }).when(client).search(any(), any()); - adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + indexAnomalyDetectorJobActionHandler.stopConfig(detectorId, historical, null, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1505,13 +1514,13 @@ public void testStopDetectorWithNonExistingTask() { public void testStopDetectorWithTaskDone() { String detectorId = randomAlphaOfLength(5); boolean historical = true; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { - Consumer> function = invocation.getArgument(1); + Consumer> function = invocation.getArgument(2); AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); function.accept(Optional.of(detector)); return null; - }).when(adTaskManager).getDetector(anyString(), any(), any()); + }).when(nodeStateManager).getConfig(anyString(), eq(AnalysisType.AD), any(Consumer.class), any()); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -1540,14 +1549,14 @@ public void testStopDetectorWithTaskDone() { return null; }).when(client).search(any(), any()); - adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + indexAnomalyDetectorJobActionHandler.stopConfig(detectorId, historical, null, transportService, listener); verify(listener, times(1)).onFailure(any()); } @SuppressWarnings("unchecked") public void testGetDetectorWithWrongContent() { String detectorId = randomAlphaOfLength(5); - Consumer> function = mock(Consumer.class); + Consumer> function = mock(Consumer.class); ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -1571,7 +1580,18 @@ public void testGetDetectorWithWrongContent() { responseListener.onResponse(response); return null; }).when(client).get(any(), any()); - adTaskManager.getDetector(detectorId, function, listener); + NodeStateManager nodeStateManager = new NodeStateManager( + client, + xContentRegistry(), + Settings.EMPTY, + mock(ClientUtil.class), + mock(Clock.class), + TimeSeriesSettings.HOURLY_MAINTENANCE, + clusterService, + 
TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + TimeSeriesSettings.BACKOFF_MINUTES + ); + nodeStateManager.getConfig(detectorId, AnalysisType.AD, function, listener); verify(listener, times(1)).onFailure(any()); } diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java index 8ce30df12..60f68131e 100644 --- a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java @@ -34,6 +34,7 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TimeSeriesTask; import org.opensearch.timeseries.util.ExceptionUtil; import com.google.common.collect.ImmutableList; @@ -102,7 +103,7 @@ public void testHistoricalAnalysisWithValidDateRange() throws IOException, Inter client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(20000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(TimeSeriesTask.STATE_FIELD))); } public void testHistoricalAnalysisWithNonExistingIndex() throws IOException { @@ -140,7 +141,7 @@ public void testDisableADPlugin() throws IOException { ImmutableList.of(NotSerializableExceptionWrapper.class, EndRunException.class), () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10000) ); - assertTrue(exception.getMessage(), exception.getMessage().contains("AD functionality is disabled")); + assertTrue(exception.getMessage(), exception.getMessage().contains("AD plugin is disabled")); updateTransientSettings(ImmutableMap.of(AD_ENABLED, false)); } finally { // guarantee reset back to default @@ -162,7 +163,7 @@ public void testMultipleTasks() throws IOException, InterruptedException { client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(25000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(TimeSeriesTask.STATE_FIELD))); updateTransientSettings(ImmutableMap.of(MAX_BATCH_TASK_PER_NODE.getKey(), 1)); } @@ -187,6 +188,6 @@ private void testInvalidDetectionDateRange(DateRange dateRange, String error) th client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(5000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertEquals(error, doc.getSourceAsMap().get(ADTask.ERROR_FIELD)); + assertEquals(error, doc.getSourceAsMap().get(TimeSeriesTask.ERROR_FIELD)); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java index 6946953fc..076c8763c 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java @@ -20,16 +20,17 @@ 
import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.ResultBulkResponse; public class ADResultBulkResponseTests extends OpenSearchTestCase { public void testSerialization() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); List retryRequests = new ArrayList<>(); retryRequests.add(new IndexRequest("index").id("blah").source(Collections.singletonMap("foo", "bar"))); - ADResultBulkResponse response = new ADResultBulkResponse(retryRequests); + ResultBulkResponse response = new ResultBulkResponse(retryRequests); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADResultBulkResponse readResponse = new ADResultBulkResponse(streamInput); + ResultBulkResponse readResponse = new ResultBulkResponse(streamInput); assertTrue(readResponse.hasFailures()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java index da8f3dce7..b545b97d8 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java @@ -18,6 +18,8 @@ import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsRequest; public class ADStatsITTests extends OpenSearchIntegTestCase { @@ -31,9 +33,9 @@ protected Collection> transportClientPlugins() { } public void testNormalADStats() throws ExecutionException, InterruptedException { - ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + StatsRequest adStatsRequest = new StatsRequest(new String[0]); - ADStatsNodesResponse response = client().execute(ADStatsNodesAction.INSTANCE, adStatsRequest).get(); + StatsNodesResponse response = client().execute(ADStatsNodesAction.INSTANCE, adStatsRequest).get(); assertTrue("getting stats failed", !response.hasFailures()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java index 4836825f3..abd8c5d8c 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.time.Clock; import java.time.Instant; +import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.TreeMap; import java.util.stream.Collectors; @@ -34,8 +36,6 @@ import org.opensearch.Version; import org.opensearch.action.FailedNodeException; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -45,9 +45,16 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; import 
org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.StatsNodeRequest; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsRequest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.gson.JsonArray; import com.google.gson.JsonElement; @@ -78,18 +85,18 @@ public void setUp() throws Exception { @Test public void testADStatsNodeRequest() throws IOException { - ADStatsNodeRequest adStatsNodeRequest1 = new ADStatsNodeRequest(); + StatsNodeRequest adStatsNodeRequest1 = new StatsNodeRequest(); assertNull("ADStatsNodeRequest default constructor failed", adStatsNodeRequest1.getADStatsRequest()); - ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); - ADStatsNodeRequest adStatsNodeRequest2 = new ADStatsNodeRequest(adStatsRequest); + StatsRequest adStatsRequest = new StatsRequest(new String[0]); + StatsNodeRequest adStatsNodeRequest2 = new StatsNodeRequest(adStatsRequest); assertEquals("ADStatsNodeRequest has the wrong ADStatsRequest", adStatsNodeRequest2.getADStatsRequest(), adStatsRequest); // Test serialization BytesStreamOutput output = new BytesStreamOutput(); adStatsNodeRequest2.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - adStatsNodeRequest1 = new ADStatsNodeRequest(streamInput); + adStatsNodeRequest1 = new StatsNodeRequest(streamInput); assertEquals( "readStats failed", adStatsNodeRequest2.getADStatsRequest().getStatsToBeRetrieved(), @@ -106,11 +113,11 @@ public void testSimpleADStatsNodeResponse() throws IOException, JsonPathNotFound }; // Test serialization - ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + StatsNodeResponse adStatsNodeResponse = new StatsNodeResponse(discoveryNode1, stats); BytesStreamOutput output = new BytesStreamOutput(); adStatsNodeResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + StatsNodeResponse readResponse = StatsNodeResponse.readStats(streamInput); assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); // Test toXContent @@ -139,25 +146,27 @@ public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotF attributes.put(name2, val2); String detectorId = "detectorId"; Entity entity = Entity.createEntityFromOrderedMap(attributes); - EntityModel entityModel = new EntityModel(entity, null, null); Clock clock = mock(Clock.class); when(clock.instant()).thenReturn(Instant.now()); - ModelState state = new ModelState( - entityModel, + ModelState state = new ModelState( + null, entity.getModelId(detectorId).get(), detectorId, - "entity", + ModelManager.ModelType.TRCF.getName(), clock, - 0.1f + 0.1f, + null, + Optional.empty(), + new ArrayDeque<>() ); Map stats = state.getModelStateAsMap(); // Test serialization - ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + StatsNodeResponse adStatsNodeResponse = new StatsNodeResponse(discoveryNode1, stats); BytesStreamOutput output = new BytesStreamOutput(); adStatsNodeResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + StatsNodeResponse readResponse = 
StatsNodeResponse.readStats(streamInput); assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); // Test toXContent @@ -192,7 +201,7 @@ public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotF @Test public void testADStatsRequest() throws IOException { List allStats = Arrays.stream(StatNames.values()).map(StatNames::getName).collect(Collectors.toList()); - ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + StatsRequest adStatsRequest = new StatsRequest(new String[0]); // Test clear() adStatsRequest.clear(); @@ -215,7 +224,7 @@ public void testADStatsRequest() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); adStatsRequest.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsRequest readRequest = new ADStatsRequest(streamInput); + StatsRequest readRequest = new StatsRequest(streamInput); assertEquals("Serialization fails", readRequest.getStatsToBeRetrieved(), adStatsRequest.getStatsToBeRetrieved()); } @@ -227,10 +236,10 @@ public void testADStatsNodesResponse() throws IOException, JsonPathNotFoundExcep } }; - ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, nodeStats); - List adStatsNodeResponses = Collections.singletonList(adStatsNodeResponse); + StatsNodeResponse adStatsNodeResponse = new StatsNodeResponse(discoveryNode1, nodeStats); + List adStatsNodeResponses = Collections.singletonList(adStatsNodeResponse); List failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(new ClusterName(clusterName), adStatsNodeResponses, failures); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(new ClusterName(clusterName), adStatsNodeResponses, failures); // Test toXContent XContentBuilder builder = jsonBuilder(); @@ -256,7 +265,7 @@ public void testADStatsNodesResponse() throws IOException, JsonPathNotFoundExcep adStatsNodesResponse.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADStatsNodesResponse readRequest = new ADStatsNodesResponse(streamInput); + StatsNodesResponse readRequest = new StatsNodesResponse(streamInput); builder = jsonBuilder(); String readJson = readRequest.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject().toString(); diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java index d8e13e9ec..23ae4b954 100644 --- a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java @@ -33,6 +33,7 @@ import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TaskProfile; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; @@ -117,7 +118,7 @@ private void testADTaskProfileResponse(ADTaskProfileNodeResponse response) throw } public void testADTaskProfileParse() throws IOException { - ADTaskProfile adTaskProfile = new ADTaskProfile( + TaskProfile adTaskProfile = new ADTaskProfile( randomAlphaOfLength(5), randomInt(), randomLong(), @@ -128,7 +129,7 @@ public void testADTaskProfileParse() throws IOException { ); String adTaskProfileString = TestHelpers .xContentBuilderToString(adTaskProfile.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADTaskProfile 
parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); + TaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); assertEquals(adTaskProfile, parsedADTaskProfile); assertEquals(parsedADTaskProfile.toString(), adTaskProfile.toString()); } @@ -170,7 +171,7 @@ public void testSerializeResponse() throws IOException { } public void testADTaskProfileParseFullConstructor() throws IOException { - ADTaskProfile adTaskProfile = new ADTaskProfile( + TaskProfile adTaskProfile = new ADTaskProfile( TestHelpers.randomAdTask(), randomInt(), randomLong(), @@ -190,7 +191,7 @@ public void testADTaskProfileParseFullConstructor() throws IOException { ); String adTaskProfileString = TestHelpers .xContentBuilderToString(adTaskProfile.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - ADTaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); + TaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); assertEquals(adTaskProfile, parsedADTaskProfile); } } diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 1b23b6d51..08dd23587 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -65,20 +65,15 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SinglePointFeatures; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -108,6 +103,7 @@ import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.LimitExceededException; @@ -115,9 +111,16 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import 
org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.transport.ResultResponse; import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.NodeNotConnectedException; import org.opensearch.transport.RemoteTransportException; @@ -139,7 +142,7 @@ public class AnomalyResultTests extends AbstractTimeSeriesTest { private ClusterService clusterService; private NodeStateManager stateManager; private FeatureManager featureQuery; - private ModelManager normalModelManager; + private ADModelManager normalModelManager; private Client client; private SecurityClientUtil clientUtil; private AnomalyDetector detector; @@ -203,7 +206,7 @@ public void setUp() throws Exception { hashRing = mock(HashRing.class); Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(localNode); doReturn(localNode).when(hashRing).getNodeByAddress(any()); featureQuery = mock(FeatureManager.class); @@ -216,7 +219,7 @@ public void setUp() throws Exception { double rcfScore = 0.2; confidence = 0.91; anomalyGrade = 0.5; - normalModelManager = mock(ModelManager.class); + normalModelManager = mock(ADModelManager.class); long totalUpdates = 1440; int relativeIndex = 0; double[] currentTimeAttribution = new double[] { 0.5, 0.5 }; @@ -288,12 +291,12 @@ public void setUp() throws Exception { indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; @@ -309,7 +312,6 @@ public void setUp() throws Exception { DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now()); listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getId(), ADCommonName.DETECTION_STATE_INDEX)); - } return null; @@ -466,7 +468,7 @@ public void sendRequest( // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor Optional discoveryNode = Optional.of(testNodes[1].discoveryNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode); when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); // register handler on testNodes[1] new RCFResultTransportAction( @@ -515,7 +517,7 @@ public void noModelExceptionTemplate(Exception exception, String adID, String er @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringColdStart() { - ModelManager rcfManager = mock(ModelManager.class); + ADModelManager rcfManager = mock(ADModelManager.class); doThrow(ResourceNotFoundException.class) .when(rcfManager) .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); @@ -563,7 +565,7 @@ public void testInsufficientCapacityExceptionDuringColdStart() { @SuppressWarnings("unchecked") public void testInsufficientCapacityExceptionDuringRestoringModel() { - ModelManager rcfManager = mock(ModelManager.class); + ADModelManager rcfManager = mock(ADModelManager.class); doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))) .when(rcfManager) .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); @@ -689,7 +691,7 @@ public void sendRequest( // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor Optional discoveryNode = Optional.of(testNodes[1].discoveryNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode); when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); // register handlers on testNodes[1] ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); @@ -798,7 +800,7 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, .when(exceptionTransportService) .getConnection(same(rcfNode)); } else { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode)); when(hashRing.getNodeByAddress(any())).thenReturn(Optional.of(thresholdNode)); doThrow(new NodeNotConnectedException(rcfNode, "rcf node not connected")) .when(exceptionTransportService) @@ -845,10 +847,10 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, assertException(listener, TimeSeriesException.class); if (!temporary) { - verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtimeAD(); + verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtime(); verify(stateManager, never()).addPressure(any(String.class), any(String.class)); } else { - verify(hashRing, never()).buildCirclesForRealtimeAD(); + verify(hashRing, never()).buildCirclesForRealtime(); verify(stateManager, times(numberOfBuildCall)).addPressure(any(String.class), any(String.class)); } } @@ -895,7 +897,7 @@ public void testMute() { action.doExecute(null, request, listener); Throwable exception = assertException(listener, 
TimeSeriesException.class); - assertThat(exception.getMessage(), containsString(AnomalyResultTransportAction.NODE_UNRESPONSIVE_ERR_MSG)); + assertThat(exception.getMessage(), containsString(ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG)); } public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOException { @@ -910,7 +912,7 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE ); Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(localNode); doReturn(localNode).when(hashRing).getNodeByAddress(any()); new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); @@ -971,7 +973,7 @@ public String executor() { } public void testSerialzationResponse() throws IOException { - AnomalyResultResponse response = new AnomalyResultResponse( + ResultResponse response = new AnomalyResultResponse( 4d, 0.993, 1.01, @@ -985,7 +987,8 @@ public void testSerialzationResponse() throws IOException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); @@ -996,7 +999,7 @@ public void testSerialzationResponse() throws IOException { } public void testJsonResponse() throws IOException, JsonPathNotFoundException { - AnomalyResultResponse response = new AnomalyResultResponse( + ResultResponse response = new AnomalyResultResponse( 4d, 0.993, 1.01, @@ -1010,7 +1013,8 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); XContentBuilder builder = jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -1042,7 +1046,8 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); assertAnomalyResultResponse(readResponse, readResponse.getAnomalyGrade(), readResponse.getConfidence(), 0d); } @@ -1054,7 +1059,7 @@ public void testSerialzationRequest() throws IOException { StreamInput streamInput = output.bytes().streamInput(); AnomalyResultRequest readRequest = new AnomalyResultRequest(streamInput); - assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + assertThat(request.getConfigId(), equalTo(readRequest.getConfigId())); assertThat(request.getStart(), equalTo(readRequest.getStart())); assertThat(request.getEnd(), equalTo(readRequest.getEnd())); } @@ -1065,7 +1070,7 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException { request.toXContent(builder, ToXContent.EMPTY_PARAMS); String json = builder.toString(); - assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getAdID()); + assertEquals(JsonDeserializer.getTextValue(json, 
ADCommonName.ID_JSON_KEY), request.getConfigId()); assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), request.getStart()); assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), request.getEnd()); } @@ -1093,25 +1098,28 @@ public void testNegativeTime() { // no exception should be thrown @SuppressWarnings("unchecked") public void testOnFailureNull() throws IOException { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, + ADResultProcessor adResultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, threadPool, + hashRing, + stateManager, + transportService, + adStats, + adTaskManager, NamedXContentRegistry.EMPTY, - adTaskManager + client, + clientUtil, + indexNameResolver, + AnomalyResultResponse.class, + featureQuery, + normalModelManager ); - AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( + ADResultProcessor.RCFActionListener listener = adResultProcessor.new RCFActionListener( null, null, null, null, mock(ActionListener.class), null, null ); listener.onFailure(null); @@ -1482,17 +1490,17 @@ private void globalBlockTemplate(BlockType type, String errLogMsg) { } public void testReadBlock() { - globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, AnomalyResultTransportAction.READ_WRITE_BLOCKED); + globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, ResultProcessor.READ_WRITE_BLOCKED); } public void testWriteBlock() { - globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, AnomalyResultTransportAction.READ_WRITE_BLOCKED); + globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, ResultProcessor.READ_WRITE_BLOCKED); } public void testIndexReadBlock() { globalBlockTemplate( BlockType.INDEX_BLOCK, - AnomalyResultTransportAction.INDEX_READ_BLOCKED, + ResultProcessor.INDEX_READ_BLOCKED, Settings.builder().put(IndexMetadata.INDEX_BLOCKS_READ_SETTING.getKey(), true).build(), "test1" ); @@ -1500,53 +1508,59 @@ public void testIndexReadBlock() { @SuppressWarnings("unchecked") public void testNullRCFResult() { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, + ADResultProcessor adResultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, threadPool, + hashRing, + stateManager, + transportService, + adStats, + adTaskManager, NamedXContentRegistry.EMPTY, - adTaskManager + client, + clientUtil, + indexNameResolver, + AnomalyResultResponse.class, + featureQuery, + normalModelManager ); - AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( + ADResultProcessor.RCFActionListener listener = adResultProcessor.new RCFActionListener( "123-rcf-0", null, "123", null, mock(ActionListener.class), null, null ); listener.onResponse(null); - 
assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE)); + assertTrue(testAppender.containsMessage(ResultProcessor.NULL_RESPONSE)); } @SuppressWarnings("unchecked") public void testNormalRCFResult() { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, + ADResultProcessor adResultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, threadPool, + hashRing, + stateManager, + transportService, + adStats, + adTaskManager, NamedXContentRegistry.EMPTY, - adTaskManager + client, + clientUtil, + indexNameResolver, + AnomalyResultResponse.class, + featureQuery, + normalModelManager ); ActionListener listener = mock(ActionListener.class); - AnomalyResultTransportAction.RCFActionListener rcfListener = action.new RCFActionListener( + ADResultProcessor.RCFActionListener rcfListener = adResultProcessor.new RCFActionListener( "123-rcf-0", null, "nodeID", detector, listener, null, adID ); double[] attribution = new double[] { 1. }; @@ -1561,27 +1575,30 @@ public void testNormalRCFResult() { @SuppressWarnings("unchecked") public void testNullPointerRCFResult() { - AnomalyResultTransportAction action = new AnomalyResultTransportAction( - new ActionFilters(Collections.emptySet()), - transportService, + ADResultProcessor adResultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, settings, - client, - clientUtil, - stateManager, - featureQuery, - normalModelManager, - hashRing, clusterService, - indexNameResolver, - adCircuitBreakerService, - adStats, threadPool, + hashRing, + stateManager, + transportService, + adStats, + adTaskManager, NamedXContentRegistry.EMPTY, - adTaskManager + client, + clientUtil, + indexNameResolver, + AnomalyResultResponse.class, + featureQuery, + normalModelManager ); ActionListener listener = mock(ActionListener.class); // detector being null causes NullPointerException - AnomalyResultTransportAction.RCFActionListener rcfListener = action.new RCFActionListener( + ADResultProcessor.RCFActionListener rcfListener = adResultProcessor.new RCFActionListener( "123-rcf-0", null, "nodeID", null, listener, null, adID ); double[] attribution = new double[] { 1. 
}; @@ -1634,7 +1651,7 @@ public void testEndRunDueToNoTrainingData() { ThreadPool mockThreadPool = mock(ThreadPool.class); setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); - ModelManager rcfManager = mock(ModelManager.class); + ADModelManager rcfManager = mock(ADModelManager.class); doAnswer(invocation -> { Object[] args = invocation.getArguments(); ActionListener listener = (ActionListener) args[3]; diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java index 78ffca8dd..0df346f14 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java @@ -220,7 +220,8 @@ private AnomalyDetector randomDetector(List indices, List featu null, null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -243,7 +244,8 @@ private AnomalyDetector randomHCDetector(List indices, List fea ImmutableList.of(categoryField), null, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } diff --git a/src/test/java/org/opensearch/ad/transport/DelegateADProfileTransportAction.java b/src/test/java/org/opensearch/ad/transport/DelegateADProfileTransportAction.java new file mode 100644 index 000000000..410fe3f9a --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DelegateADProfileTransportAction.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.ProfileNodeRequest; +import org.opensearch.timeseries.transport.ProfileNodeResponse; +import org.opensearch.timeseries.transport.ProfileRequest; +import org.opensearch.timeseries.transport.ProfileResponse; +import org.opensearch.transport.TransportService; + +/** + * This utility class serves as a delegate for testing ProfileTransportAction functionalities. + * It facilitates the invocation of protected methods within the org.opensearch.ad.transport.ADProfileTransportAction + * and org.opensearch.timeseries.transport.BaseProfileTransportAction classes, which are otherwise inaccessible + * due to Java's access control restrictions. This is achieved by extending the target classes or using reflection + * where inheritance is not possible, enabling the testing framework to perform comprehensive tests on protected + * class members across different packages. 
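+ * <p>A minimal usage sketch of the delegate in a test (illustrative only; {@code profileRequest},
+ * {@code nodeResponses}, {@code failures}, and the constructor arguments stand in for mocks or
+ * fixtures built elsewhere in the test):
+ * <pre>
+ * DelegateADProfileTransportAction delegate = new DelegateADProfileTransportAction(
+ *     threadPool, clusterService, transportService, actionFilters,
+ *     modelManager, featureManager, cacheProvider, settings
+ * );
+ * // The overrides below widen visibility, so the test can assert on the protected logic directly.
+ * ProfileResponse response = delegate.newResponse(profileRequest, nodeResponses, failures);
+ * ProfileNodeRequest nodeRequest = delegate.newNodeRequest(profileRequest);
+ * </pre>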
+ */ +public class DelegateADProfileTransportAction extends ADProfileTransportAction { + + public DelegateADProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cacheProvider, + Settings settings + ) { + super(threadPool, clusterService, transportService, actionFilters, modelManager, featureManager, cacheProvider, settings); + } + + @Override + public ProfileResponse newResponse(ProfileRequest request, List responses, List failures) { + return super.newResponse(request, responses, failures); + } + + @Override + public ProfileNodeRequest newNodeRequest(ProfileRequest request) { + return super.newNodeRequest(request); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DelegateDeleteADModelTransportAction.java b/src/test/java/org/opensearch/ad/transport/DelegateDeleteADModelTransportAction.java new file mode 100644 index 000000000..b32d81978 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DelegateDeleteADModelTransportAction.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.transport.TransportService; + +/** + * This utility class serves as a delegate for testing DeleteADModelTransportAction functionalities. + * It facilitates the invocation of protected methods within the org.opensearch.ad.transport.DeleteADModelTransportAction + * and org.opensearch.timeseries.transport.BaseDeleteModelTransportAction classes, which are otherwise inaccessible + * due to Java's access control restrictions. This is achieved by extending the target classes or using reflection + * where inheritance is not possible, enabling the testing framework to perform comprehensive tests on protected + * class members across different packages. 
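+ * <p>A minimal usage sketch (illustrative only; {@code deleteModelRequest}, {@code nodeResponses},
+ * {@code failures}, and the constructor arguments stand in for mocks or fixtures built in the test):
+ * <pre>
+ * DelegateDeleteADModelTransportAction delegate = new DelegateDeleteADModelTransportAction(
+ *     threadPool, clusterService, transportService, actionFilters, nodeStateManager,
+ *     modelManager, featureManager, cache, adTaskCacheManager, coldStarter
+ * );
+ * // newResponse is protected in the base class; the delegate re-exposes it for direct assertions.
+ * DeleteModelResponse response = delegate.newResponse(deleteModelRequest, nodeResponses, failures);
+ * </pre>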
+ */ +public class DelegateDeleteADModelTransportAction extends DeleteADModelTransportAction { + public DelegateDeleteADModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cache, + ADTaskCacheManager adTaskCacheManager, + ADEntityColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + modelManager, + featureManager, + cache, + adTaskCacheManager, + coldStarter + ); + } + + @Override + public DeleteModelResponse newResponse( + DeleteModelRequest request, + List responses, + List failures + ) { + return super.newResponse(request, responses, failures); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java index ac81ecf25..c388e7499 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java @@ -22,6 +22,7 @@ import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.transport.DeleteConfigRequest; import com.google.common.collect.ImmutableList; @@ -60,7 +61,7 @@ public void testDeleteAnomalyDetectorWithEnabledFeature() throws IOException { private void testDeleteDetector(AnomalyDetector detector) throws IOException { String detectorId = createDetector(detector); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest(detectorId); + DeleteConfigRequest request = new DeleteConfigRequest(detectorId); DeleteResponse deleteResponse = client().execute(DeleteAnomalyDetectorAction.INSTANCE, request).actionGet(10000); assertEquals("deleted", deleteResponse.getResult().getLowercase()); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java index 1a57504cc..52678e63f 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java @@ -20,6 +20,10 @@ import org.opensearch.common.action.ActionFuture; import org.opensearch.plugins.Plugin; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; public class DeleteITTests extends ADIntegTestCase { @@ -28,23 +32,24 @@ protected Collection> nodePlugins() { return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } + @Override protected Collection> transportClientPlugins() { return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testNormalStopDetector() throws ExecutionException, InterruptedException { - StopDetectorRequest request = new StopDetectorRequest().adID("123"); + StopConfigRequest request = new StopConfigRequest().adID("123"); - ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); + ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); - StopDetectorResponse 
response = future.get(); + StopConfigResponse response = future.get(); assertTrue(response.success()); } public void testNormalDeleteModel() throws ExecutionException, InterruptedException { DeleteModelRequest request = new DeleteModelRequest("123"); - ActionFuture future = client().execute(DeleteModelAction.INSTANCE, request); + ActionFuture future = client().execute(DeleteADModelAction.INSTANCE, request); DeleteModelResponse response = future.get(); assertTrue(!response.hasFailures()); @@ -53,15 +58,15 @@ public void testNormalDeleteModel() throws ExecutionException, InterruptedExcept public void testEmptyIDDeleteModel() throws ExecutionException, InterruptedException { DeleteModelRequest request = new DeleteModelRequest(""); - ActionFuture future = client().execute(DeleteModelAction.INSTANCE, request); + ActionFuture future = client().execute(DeleteADModelAction.INSTANCE, request); expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); } public void testEmptyIDStopDetector() throws ExecutionException, InterruptedException { - StopDetectorRequest request = new StopDetectorRequest(); + StopConfigRequest request = new StopConfigRequest(); - ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); + ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java index b76925492..46e8af186 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java @@ -27,13 +27,12 @@ import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -46,6 +45,13 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.CronNodeResponse; +import org.opensearch.timeseries.transport.CronResponse; +import org.opensearch.timeseries.transport.DeleteModelNodeRequest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; import org.opensearch.transport.TransportService; import com.google.gson.JsonElement; @@ -53,7 +59,7 @@ import test.org.opensearch.ad.util.JsonDeserializer; public class DeleteModelTransportActionTests extends AbstractTimeSeriesTest { - private DeleteModelTransportAction 
action; + private DelegateDeleteADModelTransportAction action; private String localNodeID; @Override @@ -70,15 +76,15 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); NodeStateManager nodeStateManager = mock(NodeStateManager.class); - ModelManager modelManager = mock(ModelManager.class); + ADModelManager modelManager = mock(ADModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); - CacheProvider cacheProvider = mock(CacheProvider.class); - EntityCache entityCache = mock(EntityCache.class); + ADCacheProvider cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache entityCache = mock(ADPriorityCache.class); when(cacheProvider.get()).thenReturn(entityCache); ADTaskCacheManager adTaskCacheManager = mock(ADTaskCacheManager.class); - EntityColdStarter coldStarter = mock(EntityColdStarter.class); + ADEntityColdStart coldStarter = mock(ADEntityColdStart.class); - action = new DeleteModelTransportAction( + action = new DelegateDeleteADModelTransportAction( threadPool, clusterService, transportService, diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java b/src/test/java/org/opensearch/ad/transport/DeleteTests.java index 4821cbfbd..9e5cff5df 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java @@ -58,6 +58,11 @@ import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; @@ -139,7 +144,7 @@ public void testSerialzationResponse() throws IOException { response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - DeleteModelResponse readResponse = DeleteModelAction.INSTANCE.getResponseReader().read(streamInput); + DeleteModelResponse readResponse = DeleteADModelAction.INSTANCE.getResponseReader().read(streamInput); assertTrue(readResponse.hasFailures()); assertEquals(failures.size(), readResponse.failures().size()); @@ -152,12 +157,12 @@ public void testEmptyIDDeleteModel() { } public void testEmptyIDStopDetector() { - ActionRequestValidationException e = new StopDetectorRequest().validate(); + ActionRequestValidationException e = new StopConfigRequest().validate(); assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); } public void testValidIDStopDetector() { - ActionRequestValidationException e = new StopDetectorRequest().adID("foo").validate(); + ActionRequestValidationException e = new StopConfigRequest().adID("foo").validate(); assertThat(e, is(nullValue())); } @@ -171,12 +176,12 @@ public void testSerialzationRequestDeleteModel() throws IOException { } public void testSerialzationRequestStopDetector() throws IOException { - StopDetectorRequest request = new StopDetectorRequest().adID("123"); + StopConfigRequest request = new StopConfigRequest().adID("123"); BytesStreamOutput output = new BytesStreamOutput(); request.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - 
StopDetectorRequest readRequest = new StopDetectorRequest(streamInput); - assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + StopConfigRequest readRequest = new StopConfigRequest(streamInput); + assertThat(request.getConfigID(), equalTo(readRequest.getConfigID())); } public void testJsonRequestTemplate(R request, Supplier requestSupplier) throws IOException, @@ -189,8 +194,8 @@ public void testJsonRequestTemplate(R request, Supplier listener = new PlainActionFuture<>(); + StopConfigRequest request = new StopConfigRequest().adID(detectorID); + PlainActionFuture listener = new PlainActionFuture<>(); action.doExecute(task, request, listener); - StopDetectorResponse response = listener.actionGet(); + StopConfigResponse response = listener.actionGet(); assertTrue(!response.success()); } diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java index f7eb2c8e9..77aa86e02 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java @@ -15,7 +15,6 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -50,27 +49,23 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.ad.AnomalyDetectorJobRunnerTests; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.constant.CommonValue; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; +import org.opensearch.ad.ratelimit.ADSaveResultStrategy; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -87,10 +82,16 @@ import 
diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java
index f7eb2c8e9..77aa86e02 100644
--- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java
@@ -15,7 +15,6 @@
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.startsWith;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
@@ -50,27 +49,23 @@
 import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.action.support.master.AcknowledgedResponse;
 import org.opensearch.ad.AnomalyDetectorJobRunnerTests;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
+import org.opensearch.ad.caching.ADCacheProvider;
+import org.opensearch.ad.caching.ADPriorityCache;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.constant.CommonValue;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.ml.CheckpointDao;
-import org.opensearch.ad.ml.EntityColdStarter;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelManager;
-import org.opensearch.ad.ml.ModelState;
+import org.opensearch.ad.ml.ADCheckpointDao;
+import org.opensearch.ad.ml.ADEntityColdStart;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.ratelimit.CheckpointReadWorker;
-import org.opensearch.ad.ratelimit.ColdEntityWorker;
-import org.opensearch.ad.ratelimit.EntityColdStartWorker;
-import org.opensearch.ad.ratelimit.ResultWriteWorker;
+import org.opensearch.ad.ratelimit.ADCheckpointReadWorker;
+import org.opensearch.ad.ratelimit.ADColdEntityWorker;
+import org.opensearch.ad.ratelimit.ADColdStartWorker;
+import org.opensearch.ad.ratelimit.ADResultWriteWorker;
+import org.opensearch.ad.ratelimit.ADSaveResultStrategy;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
-import org.opensearch.ad.stats.ADStat;
 import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
@@ -87,10 +82,16 @@
 import org.opensearch.timeseries.common.exception.LimitExceededException;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.constant.CommonValue;
+import org.opensearch.timeseries.ml.ModelState;
 import org.opensearch.timeseries.model.Entity;
 import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.transport.EntityResultRequest;
 import org.opensearch.transport.TransportService;

+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
 import com.google.gson.JsonArray;
 import com.google.gson.JsonElement;
@@ -99,14 +100,14 @@
 import test.org.opensearch.ad.util.RandomModelStateConfig;

 public class EntityResultTransportActionTests extends AbstractTimeSeriesTest {
-    EntityResultTransportAction entityResult;
+    EntityADResultTransportAction entityResult;
     ActionFilters actionFilters;
     TransportService transportService;
-    ModelManager manager;
+    ADModelManager manager;
     CircuitBreakerService adCircuitBreakerService;
-    CheckpointDao checkpointDao;
-    CacheProvider provider;
-    EntityCache entityCache;
+    ADCheckpointDao checkpointDao;
+    ADCacheProvider provider;
+    ADPriorityCache entityCache;
     NodeStateManager stateManager;
     Settings settings;
     Clock clock;
@@ -125,13 +126,13 @@ public class EntityResultTransportActionTests extends AbstractTimeSeriesTest {
     double[] cacheHitData;
     String tooLongEntity;
     double[] tooLongData;
-    ResultWriteWorker resultWriteQueue;
-    CheckpointReadWorker checkpointReadQueue;
+    ADResultWriteWorker resultWriteQueue;
+    ADCheckpointReadWorker checkpointReadQueue;
     int minSamples;
     Instant now;
-    EntityColdStarter coldStarter;
-    ColdEntityWorker coldEntityQueue;
-    EntityColdStartWorker entityColdStartQueue;
+    ADEntityColdStart coldStarter;
+    ADColdEntityWorker coldEntityQueue;
+    ADColdStartWorker entityColdStartQueue;
     ADIndexManagement indexUtil;
     ClusterService clusterService;
     ADStats adStats;
@@ -157,14 +158,14 @@ public void setUp() throws Exception {
         adCircuitBreakerService = mock(CircuitBreakerService.class);
         when(adCircuitBreakerService.isOpen()).thenReturn(false);

-        checkpointDao = mock(CheckpointDao.class);
+        checkpointDao = mock(ADCheckpointDao.class);

         detectorId = "123";
         entities = new HashMap<>();

         start = 10L;
         end = 20L;
-        request = new EntityResultRequest(detectorId, entities, start, end);
+        request = new EntityResultRequest(detectorId, entities, start, end, AnalysisType.AD, null);

         clock = mock(Clock.class);
         now = Instant.now();
@@ -182,7 +183,7 @@ public void setUp() throws Exception {
             Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ)))
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
-        manager = new ModelManager(
+        manager = new ADModelManager(
             null,
             clock,
             0,
@@ -193,15 +194,15 @@ public void setUp() throws Exception {
             0,
             null,
             AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ,
-            mock(EntityColdStarter.class),
+            mock(ADEntityColdStart.class),
             null,
             null,
             settings,
             clusterService
         );

-        provider = mock(CacheProvider.class);
-        entityCache = mock(EntityCache.class);
+        provider = mock(ADCacheProvider.class);
+        entityCache = mock(ADPriorityCache.class);
         when(provider.get()).thenReturn(entityCache);

         String field = "a";
@@ -225,7 +226,8 @@ public void setUp() throws Exception {
         tooLongData = new double[] { 0.3 };
         entities.put(Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), tooLongEntity), tooLongData);

-        ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());
+        ModelState<ThresholdedRandomCutForest> state = MLUtil
+            .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());

         when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null);
         when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state);
@@ -236,31 +238,31 @@ public void setUp() throws Exception {
         indexUtil = mock(ADIndexManagement.class);
         when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION);

-        resultWriteQueue = mock(ResultWriteWorker.class);
-        checkpointReadQueue = mock(CheckpointReadWorker.class);
+        resultWriteQueue = mock(ADResultWriteWorker.class);
+        checkpointReadQueue = mock(ADCheckpointReadWorker.class);

         minSamples = 1;

-        coldStarter = mock(EntityColdStarter.class);
+        coldStarter = mock(ADEntityColdStart.class);

         doAnswer(invocation -> {
-            ModelState<EntityModel> modelState = invocation.getArgument(0);
-            modelState.getModel().clear();
+            ModelState<ThresholdedRandomCutForest> modelState = invocation.getArgument(0);
+            modelState.clear();
             return null;
-        }).when(coldStarter).trainModelFromExistingSamples(any(), anyInt());
+        }).when(coldStarter).trainModelFromExistingSamples(any(), any(), any(), any());

-        coldEntityQueue = mock(ColdEntityWorker.class);
-        entityColdStartQueue = mock(EntityColdStartWorker.class);
+        coldEntityQueue = mock(ADColdEntityWorker.class);
+        entityColdStartQueue = mock(ADColdStartWorker.class);

-        Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() {
+        Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
             }
         };

         adStats = new ADStats(statsMap);

-        entityResult = new EntityResultTransportAction(
+        entityResult = new EntityADResultTransportAction(
             actionFilters,
             transportService,
             manager,
@@ -273,7 +275,8 @@ public void setUp() throws Exception {
             coldEntityQueue,
             threadPool,
             entityColdStartQueue,
-            adStats
+            adStats,
+            mock(ADSaveResultStrategy.class)
         );

         // timeout in 60 seconds
@@ -317,7 +320,8 @@ public void testFailtoGetDetector() {

     // test rcf score is 0
     public void testNoResultsToSave() {
-        ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build());
+        ModelState<ThresholdedRandomCutForest> state = MLUtil
+            .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build());
         when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state);

         PlainActionFuture<AcknowledgedResponse> future = PlainActionFuture.newFuture();
@@ -335,19 +339,19 @@ public void testValidRequest() {
     }

     public void testEmptyId() {
-        request = new EntityResultRequest("", entities, start, end);
+        request = new EntityResultRequest("", entities, start, end, AnalysisType.AD, null);
         ActionRequestValidationException e = request.validate();
         assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG));
     }

     public void testReverseTime() {
-        request = new EntityResultRequest(detectorId, entities, end, start);
+        request = new EntityResultRequest(detectorId, entities, end, start, AnalysisType.AD, null);
         ActionRequestValidationException e = request.validate();
         assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG)));
     }

     public void testNegativeTime() {
-        request = new EntityResultRequest(detectorId, entities, start, -end);
+        request = new EntityResultRequest(detectorId, entities, start, -end, AnalysisType.AD, null);
         ActionRequestValidationException e = request.validate();
         assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG)));
     }
@@ -384,9 +388,9 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException {
     }

     public void testFailToScore() {
-        ModelManager spyModelManager = spy(manager);
-        doThrow(new IllegalArgumentException()).when(spyModelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt());
-        entityResult = new EntityResultTransportAction(
+        ADModelManager spyModelManager = spy(manager);
+        doThrow(new IllegalArgumentException()).when(spyModelManager).getResult(any(), any(), anyString(), any(), any(), any());
+        entityResult = new EntityADResultTransportAction(
             actionFilters,
             transportService,
             spyModelManager,
@@ -399,7 +403,8 @@ public void testFailToScore() {
             coldEntityQueue,
             threadPool,
             entityColdStartQueue,
-            adStats
+            adStats,
+            mock(ADSaveResultStrategy.class)
         );

         PlainActionFuture<AcknowledgedResponse> future = PlainActionFuture.newFuture();
@@ -409,9 +414,9 @@ public void testFailToScore() {
         future.actionGet(timeoutMs);

         verify(resultWriteQueue, never()).put(any());
-        verify(entityCache, times(1)).removeEntityModel(anyString(), anyString());
+        verify(entityCache, times(1)).removeModel(anyString(), anyString());
         verify(entityColdStartQueue, times(1)).put(any());
-        Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue();
+        Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue();
         assertEquals(1L, ((Long) val).longValue());
     }
 }
diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java
index 633a9a4fe..d8a8b1c75 100644
--- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java
@@ -78,7 +78,8 @@ public void testNullDetectorIdAndTaskAction() throws IOException {
             null,
             randomUser(),
             null,
-            TestHelpers.randomImputationOption()
+            TestHelpers.randomImputationOption(),
+            randomDouble()
         );
         ForwardADTaskRequest request = new ForwardADTaskRequest(detector, null, null, null, null, Version.V_2_1_0);
         ActionRequestValidationException validate = request.validate();
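Note: EntityResultRequest moved to org.opensearch.timeseries.transport and its constructor grew two arguments, the analysis type and a task id (passed as null throughout these tests). A hedged sketch of the validation behavior the tests above exercise; the constructor order is read off this hunk:

import java.util.HashMap;
import java.util.Map;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.model.Entity;
import org.opensearch.timeseries.transport.EntityResultRequest;

public class EntityResultRequestSketch {
    public static void main(String[] args) {
        Map<Entity, double[]> entities = new HashMap<>();
        // Well-formed window: start < end, non-negative timestamps.
        EntityResultRequest ok = new EntityResultRequest("detector-1", entities, 10L, 20L, AnalysisType.AD, null);
        assert ok.validate() == null;
        // Reversed window, mirroring testReverseTime above: validation should fail.
        ActionRequestValidationException e =
            new EntityResultRequest("detector-1", entities, 20L, 10L, AnalysisType.AD, null).validate();
        assert e != null && !e.validationErrors().isEmpty();
    }
}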
diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java
index a45fb6779..fdbbbae49 100644
--- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java
@@ -31,15 +31,17 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.ad.ADUnitTestCase;
-import org.opensearch.ad.feature.FeatureManager;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.ADTaskType;
+import org.opensearch.ad.rest.handler.ADIndexJobActionHandler;
 import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.tasks.Task;
 import org.opensearch.timeseries.NodeStateManager;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.feature.FeatureManager;
+import org.opensearch.timeseries.transport.JobResponse;
 import org.opensearch.transport.TransportService;

 import com.google.common.collect.ImmutableList;
@@ -53,7 +55,7 @@ public class ForwardADTaskTransportActionTests extends ADUnitTestCase {
     private NodeStateManager stateManager;
     private ForwardADTaskTransportAction forwardADTaskTransportAction;
     private Task task;
-    private ActionListener<AnomalyDetectorJobResponse> listener;
+    private ActionListener<JobResponse> listener;

     @SuppressWarnings("unchecked")
     @Override
@@ -71,7 +73,8 @@ public void setUp() throws Exception {
             adTaskManager,
             adTaskCacheManager,
             featureManager,
-            stateManager
+            stateManager,
+            mock(ADIndexJobActionHandler.class)
         );

         task = mock(Task.class);
diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java
index 2a0b677ed..d41f255f4 100644
--- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java
@@ -31,6 +31,7 @@
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
 import org.opensearch.timeseries.model.Feature;
 import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.transport.GetConfigRequest;

 import com.google.common.collect.ImmutableList;

@@ -53,11 +54,11 @@ protected NamedWriteableRegistry writableRegistry() {

     public void testGetRequest() throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
-        GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, false, false, "nonempty", "", false, null);
+        GetConfigRequest request = new GetConfigRequest("1234", 4321, false, false, "nonempty", "", false, null);
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input);
-        Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID());
+        GetConfigRequest newRequest = new GetConfigRequest(input);
+        Assert.assertEquals(request.getConfigID(), newRequest.getConfigID());
     }
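Note: the listener type in ForwardADTaskTransportActionTests moves from the AD-specific job response to the shared org.opensearch.timeseries.transport.JobResponse. A hedged sketch of the Mockito doAnswer stubbing idiom these tests lean on; the JobStarter interface is hypothetical and JobResponse's single-String constructor is an assumption, not confirmed by this hunk:

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.transport.JobResponse;

public class ListenerStubSketch {
    // Hypothetical collaborator, for illustration only.
    interface JobStarter {
        void start(String configId, ActionListener<JobResponse> listener);
    }

    public static void main(String[] args) {
        JobStarter starter = mock(JobStarter.class);
        doAnswer(invocation -> {
            // The index passed to getArgument must track the mocked signature;
            // compare the getArgument(1) -> getArgument(2) fix later in this PR.
            ActionListener<JobResponse> listener = invocation.getArgument(1);
            listener.onResponse(new JobResponse("task-id")); // assumed constructor shape
            return null;
        }).when(starter).start(anyString(), any());
    }
}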
diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java
index 1181b685f..7aed2eae0 100644
--- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java
+++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java
@@ -21,7 +21,6 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;

-import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Collections;
@@ -38,7 +37,7 @@
 import org.opensearch.action.get.MultiGetResponse;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.PlainActionFuture;
-import org.opensearch.ad.constant.ADCommonMessages;
+import org.opensearch.ad.ADTaskProfileRunner;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
@@ -50,11 +49,14 @@
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.bytes.BytesReference;
 import org.opensearch.index.get.GetResult;
+import org.opensearch.telemetry.tracing.noop.NoopTracer;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
 import org.opensearch.timeseries.NodeStateManager;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.transport.EntityProfileTests;
+import org.opensearch.timeseries.transport.GetConfigRequest;
 import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
 import org.opensearch.timeseries.util.SecurityClientUtil;
 import org.opensearch.transport.Transport;
@@ -67,7 +69,7 @@ public class GetAnomalyDetectorTests extends AbstractTimeSeriesTest {
     private ActionFilters actionFilters;
     private Client client;
     private SecurityClientUtil clientUtil;
-    private GetAnomalyDetectorRequest request;
+    private GetConfigRequest request;
     private String detectorId = "yecrdnUBqurvo9uKU_d8";
     private String entityValue = "app_0";
     private String categoryField = "categoryField";
@@ -104,7 +106,8 @@ public void setUp() throws Exception {
             TransportService.NOOP_TRANSPORT_INTERCEPTOR,
             x -> null,
             null,
-            Collections.emptySet()
+            Collections.emptySet(),
+            NoopTracer.INSTANCE
         );

         nodeFilter = mock(DiscoveryNodeFilterer.class);
@@ -119,6 +122,8 @@ public void setUp() throws Exception {

         adTaskManager = mock(ADTaskManager.class);

+        ADTaskProfileRunner adTaskProfileRunner = mock(ADTaskProfileRunner.class);
+
         action = new GetAnomalyDetectorTransportAction(
             transportService,
             nodeFilter,
@@ -128,26 +133,27 @@ public void setUp() throws Exception {
             clientUtil,
             Settings.EMPTY,
             xContentRegistry(),
-            adTaskManager
+            adTaskManager,
+            adTaskProfileRunner
         );

         entity = Entity.createSingleAttributeEntity(categoryField, entityValue);
     }

-    public void testInvalidRequest() throws IOException {
+    public void testInvalidRequest() {
         typeStr = "entity_info2,init_progress2";

         rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile";

-        request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity);
+        request = new GetConfigRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity);

         future = new PlainActionFuture<>();
         action.doExecute(null, request, future);
-        assertException(future, OpenSearchStatusException.class, ADCommonMessages.EMPTY_PROFILES_COLLECT);
+        assertException(future, OpenSearchStatusException.class, CommonMessages.EMPTY_PROFILES_COLLECT);
     }

     @SuppressWarnings("unchecked")
-    public void testValidRequest() throws IOException {
+    public void testValidRequest() {
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
             GetRequest request = (GetRequest) args[0];
@@ -164,7 +170,7 @@ public void testValidRequest() {

         rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile";

-        request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity);
+        request = new GetConfigRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity);

         future = new PlainActionFuture<>();
         action.doExecute(null, request, future);
@@ -180,17 +186,7 @@ public void testGetTransportActionWithReturnTask() {
             return null;
         })
             .when(adTaskManager)
-            .getAndExecuteOnLatestADTasks(
-                anyString(),
-                eq(null),
-                eq(null),
-                anyList(),
-                any(),
-                eq(transportService),
-                eq(true),
-                anyInt(),
-                any()
-            );
+            .getAndExecuteOnLatestTasks(anyString(), eq(null), eq(null), anyList(), any(), eq(transportService), eq(true), anyInt(), any());

         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
@@ -202,7 +198,7 @@ public void testGetTransportActionWithReturnTask() {

         rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_";

-        request = new GetAnomalyDetectorRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity);
+        request = new GetConfigRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity);

         future = new PlainActionFuture<>();
         action.getExecute(request, future);
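Note: these tests drive transport actions through PlainActionFuture, which is simultaneously an ActionListener and a blocking future, so assertException can unwrap the transported failure. A minimal self-contained sketch of the idiom (fetchGreeting is invented for illustration):

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.core.action.ActionListener;

public class FutureListenerSketch {
    // Stand-in for an asynchronous transport call; invented for this sketch.
    static void fetchGreeting(ActionListener<String> listener) {
        listener.onResponse("hello");
    }

    public static void main(String[] args) {
        PlainActionFuture<String> future = new PlainActionFuture<>();
        fetchGreeting(future);                     // the future is passed where a listener is expected
        String result = future.actionGet(10_000L); // block, rethrowing any transported failure
        assert "hello".equals(result);
    }
}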
diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java
index 35f6ba36f..1c12aeb08 100644
--- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java
@@ -25,14 +25,11 @@
 import org.junit.*;
 import org.mockito.Mockito;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.constant.ADCommonName;
+import org.opensearch.ad.ADTaskProfileRunner;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.EntityProfile;
-import org.opensearch.ad.model.InitProgressProfile;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.util.*;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.io.stream.BytesStreamOutput;
 import org.opensearch.common.settings.ClusterSettings;
@@ -50,8 +47,12 @@
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.NodeStateManager;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.model.EntityProfile;
+import org.opensearch.timeseries.model.InitProgressProfile;
 import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.transport.GetConfigRequest;
 import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
 import org.opensearch.timeseries.util.RestHandlerUtils;
 import org.opensearch.timeseries.util.SecurityClientUtil;
@@ -101,7 +102,8 @@ public void setUp() throws Exception {
             clientUtil,
             Settings.EMPTY,
             xContentRegistry(),
-            adTaskManager
+            adTaskManager,
+            mock(ADTaskProfileRunner.class)
         );
         task = Mockito.mock(Task.class);
         response = new ActionListener<GetAnomalyDetectorResponse>() {
@@ -126,14 +128,14 @@ protected NamedWriteableRegistry writableRegistry() {

     @Test
     public void testGetTransportAction() throws IOException {
-        GetAnomalyDetectorRequest getConfigRequest = new GetAnomalyDetectorRequest("1234", 4321, false, false, "nonempty", "", false, null);
-        action.doExecute(task, getConfigRequest, response);
+        GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest("1234", 4321, false, false, "nonempty", "", false, null);
+        action.doExecute(task, getAnomalyDetectorRequest, response);
     }

     @Test
     public void testGetTransportActionWithReturnJob() throws IOException {
-        GetAnomalyDetectorRequest getConfigRequest = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, null);
-        action.doExecute(task, getConfigRequest, response);
+        GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, null);
+        action.doExecute(task, getAnomalyDetectorRequest, response);
     }

     @Test
@@ -144,23 +146,23 @@ public void testGetAction() {

     @Test
     public void testGetAnomalyDetectorRequest() throws IOException {
-        GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, entity);
+        GetConfigRequest request = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, entity);
         BytesStreamOutput out = new BytesStreamOutput();
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input);
-        Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID());
+        GetConfigRequest newRequest = new GetConfigRequest(input);
+        Assert.assertEquals(request.getConfigID(), newRequest.getConfigID());
         Assert.assertEquals(request.getRawPath(), newRequest.getRawPath());
         Assert.assertNull(newRequest.validate());
     }

     @Test
     public void testGetAnomalyDetectorRequestNoEntityValue() throws IOException {
-        GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, null);
+        GetConfigRequest request = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, null);
         BytesStreamOutput out = new BytesStreamOutput();
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input);
+        GetConfigRequest newRequest = new GetConfigRequest(input);
         Assert.assertNull(newRequest.getEntity());
     }

@@ -230,7 +232,7 @@ public void testGetAnomalyDetectorProfileResponse() throws IOException {
         // {init_progress={percentage=99%, estimated_minutes_left=2, needed_shingles=2}}
         Map<String, Object> map = TestHelpers.XContentBuilderToMap(builder);
-        Map<String, Object> parsedInitProgress = (Map<String, Object>) (map.get(ADCommonName.INIT_PROGRESS));
+        Map<String, Object> parsedInitProgress = (Map<String, Object>) (map.get(CommonName.INIT_PROGRESS));
         Assert.assertEquals(initProgress.getPercentage(), parsedInitProgress.get(InitProgressProfile.PERCENTAGE).toString());
         assertTrue(initProgress.toString().contains("[percentage=99%,estimated_minutes_left=2,needed_shingles=2]"));
         Assert
diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java
index f29030912..e4f160aa1 100644
--- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java
@@ -34,6 +34,7 @@
 import com.google.common.collect.ImmutableMap;

 public class IndexAnomalyDetectorActionTests extends OpenSearchSingleNodeTestCase {
+    @Override
     @Before
     public void setUp() throws Exception {
         super.setUp();
@@ -58,7 +59,8 @@ public void testIndexRequest() throws Exception {
             TimeValue.timeValueSeconds(60),
             1000,
             10,
-            5
+            5,
+            10
         );
         request.writeTo(out);
         NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry());
diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java
index d370fa703..c59108c17 100644
--- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java
@@ -176,7 +176,8 @@ public void setUp() throws Exception {
             TimeValue.timeValueSeconds(60),
             1000,
             10,
-            5
+            5,
+            10
         );
         response = new ActionListener<IndexAnomalyDetectorResponse>() {
             @Override
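Note: both Index* tests pass one extra trailing argument (10) to the request constructor, presumably an added limit threaded through elsewhere in this PR; the parameter name is not visible in these hunks. The other change, adding @Override to a @Before setUp(), keeps the JUnit 4 fixture chain explicit. A minimal self-contained shape of that lifecycle pattern:

import org.junit.Before;

public class LifecycleSketch {
    static class BaseTestCase {
        protected boolean ready;
        public void setUp() throws Exception { ready = true; } // parent fixture
    }

    static class ChildTestCase extends BaseTestCase {
        @Override
        @Before
        public void setUp() throws Exception {
            super.setUp(); // without this call the parent fixture is silently skipped
        }
    }
}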
diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java
index 94e07fe3c..6f3d43a8c 100644
--- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java
+++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java
@@ -13,7 +13,6 @@
 import static org.hamcrest.Matchers.containsString;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
@@ -66,24 +65,19 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.action.support.master.AcknowledgedResponse;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
-import org.opensearch.ad.cluster.HashRing;
-import org.opensearch.ad.feature.CompositeRetriever;
-import org.opensearch.ad.feature.FeatureManager;
+import org.opensearch.ad.caching.ADCacheProvider;
+import org.opensearch.ad.caching.ADPriorityCache;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.ml.ModelManager;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.ml.ThresholdingResult;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.ratelimit.CheckpointReadWorker;
-import org.opensearch.ad.ratelimit.ColdEntityWorker;
-import org.opensearch.ad.ratelimit.EntityColdStartWorker;
-import org.opensearch.ad.ratelimit.EntityFeatureRequest;
-import org.opensearch.ad.ratelimit.ResultWriteWorker;
+import org.opensearch.ad.ratelimit.ADCheckpointReadWorker;
+import org.opensearch.ad.ratelimit.ADColdEntityWorker;
+import org.opensearch.ad.ratelimit.ADColdStartWorker;
+import org.opensearch.ad.ratelimit.ADResultWriteWorker;
+import org.opensearch.ad.ratelimit.ADSaveResultStrategy;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
-import org.opensearch.ad.stats.ADStat;
 import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
@@ -113,15 +107,22 @@
 import org.opensearch.timeseries.NodeStateManager;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.cluster.HashRing;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.InternalFailure;
 import org.opensearch.timeseries.common.exception.LimitExceededException;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.feature.CompositeRetriever;
+import org.opensearch.timeseries.feature.FeatureManager;
 import org.opensearch.timeseries.model.Entity;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.ratelimit.FeatureRequest;
 import org.opensearch.timeseries.settings.TimeSeriesSettings;
 import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.transport.ResultProcessor;
 import org.opensearch.timeseries.util.ClientUtil;
 import org.opensearch.timeseries.util.SecurityClientUtil;
 import org.opensearch.transport.Transport;
@@ -149,7 +150,7 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest {
     private Client client;
     private SecurityClientUtil clientUtil;
     private FeatureManager featureQuery;
-    private ModelManager normalModelManager;
+    private ADModelManager normalModelManager;
     private HashRing hashRing;
     private ClusterService clusterService;
     private IndexNameExpressionResolver indexNameResolver;
@@ -158,12 +159,12 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest {
     private ThreadPool mockThreadPool;
     private String detectorId;
     private Instant now;
-    private CacheProvider provider;
+    private ADCacheProvider provider;
     private ADIndexManagement indexUtil;
-    private ResultWriteWorker resultWriteQueue;
-    private CheckpointReadWorker checkpointReadQueue;
-    private EntityColdStartWorker entityColdStartQueue;
-    private ColdEntityWorker coldEntityQueue;
+    private ADResultWriteWorker resultWriteQueue;
+    private ADCheckpointReadWorker checkpointReadQueue;
+    private ADColdStartWorker entityColdStartQueue;
+    private ADColdEntityWorker coldEntityQueue;
     private String app0 = "app_0";
     private String server1 = "server_1";
     private String server2 = "server_2";
@@ -171,7 +172,7 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest {
     private String serviceField = "service";
     private String hostField = "host";
     private Map<String, Object> attrs1, attrs2, attrs3;
-    private EntityCache entityCache;
+    private ADPriorityCache entityCache;
     private ADTaskManager adTaskManager;

     @BeforeClass
@@ -222,7 +223,7 @@ public void setUp() throws Exception {

         featureQuery = mock(FeatureManager.class);

-        normalModelManager = mock(ModelManager.class);
+        normalModelManager = mock(ADModelManager.class);

         hashRing = mock(HashRing.class);

@@ -248,13 +249,13 @@ public void setUp() throws Exception {
         adCircuitBreakerService = mock(CircuitBreakerService.class);
         when(adCircuitBreakerService.isOpen()).thenReturn(false);

-        Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() {
+        Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
             }
         };
         adStats = new ADStats(statsMap);
@@ -292,19 +293,19 @@ public void setUp() throws Exception {
             adTaskManager
         );

-        provider = mock(CacheProvider.class);
-        entityCache = mock(EntityCache.class);
+        provider = mock(ADCacheProvider.class);
+        entityCache = mock(ADPriorityCache.class);
         when(provider.get()).thenReturn(entityCache);
         when(entityCache.get(any(), any()))
             .thenReturn(MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()));
         when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(new ArrayList<Entity>(), new ArrayList<Entity>()));

         indexUtil = mock(ADIndexManagement.class);
-        resultWriteQueue = mock(ResultWriteWorker.class);
-        checkpointReadQueue = mock(CheckpointReadWorker.class);
-        entityColdStartQueue = mock(EntityColdStartWorker.class);
+        resultWriteQueue = mock(ADResultWriteWorker.class);
+        checkpointReadQueue = mock(ADCheckpointReadWorker.class);
+        entityColdStartQueue = mock(ADColdStartWorker.class);

-        coldEntityQueue = mock(ColdEntityWorker.class);
+        coldEntityQueue = mock(ADColdEntityWorker.class);

         attrs1 = new HashMap<>();
         attrs1.put(serviceField, app0);
@@ -328,17 +329,17 @@ public final void tearDown() throws Exception {

     public void testColdStartEndRunException() {
         when(stateManager.fetchExceptionAndClear(anyString()))
-            .thenReturn(
+            .thenReturn(
                 Optional
-                .of(
+                    .of(
                     new EndRunException(
-                        detectorId,
-                        CommonMessages.INVALID_SEARCH_QUERY_MSG,
-                        new NoSuchElementException("No value present"),
-                        false
+                        detectorId,
+                        CommonMessages.INVALID_SEARCH_QUERY_MSG,
+                        new NoSuchElementException("No value present"),
+                        false
+                    )
                 )
-                )
-            );
+            );
         PlainActionFuture<AnomalyResultResponse> listener = new PlainActionFuture<>();
         action.doExecute(null, request, listener);
         assertException(listener, EndRunException.class, CommonMessages.INVALID_SEARCH_QUERY_MSG);
@@ -397,7 +398,7 @@ public String executor() {

     private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) {
         // register entity result action
-        new EntityResultTransportAction(
+        new EntityADResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             // since we send requests to testNodes[1]
             testNodes[nodeIndex].transportService,
@@ -411,11 +412,11 @@ private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) {
             coldEntityQueue,
             threadPool,
             entityColdStartQueue,
-            adStats
+            adStats,
+            mock(ADSaveResultStrategy.class)
         );

-        when(normalModelManager.getAnomalyResultForEntity(any(), any(), any(), any(), anyInt()))
-            .thenReturn(new ThresholdingResult(0, 1, 1));
+        when(normalModelManager.getResult(any(), any(), any(), any(), any(), any())).thenReturn(new ThresholdingResult(0, 1, 1));
     }

     private void setUpEntityResult(int nodeIndex) {
@@ -430,7 +431,7 @@ public void setUpNormlaStateManager() throws IOException {
             .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5)))
             .build();
         doAnswer(invocation -> {
-            ActionListener<GetResponse> listener = invocation.getArgument(1);
+            ActionListener<GetResponse> listener = invocation.getArgument(2);
             listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX));
             return null;
         }).when(client).get(any(GetRequest.class), any(ActionListener.class));
@@ -538,11 +539,7 @@ public void testIndexNotFound() throws InterruptedException, IOException {
         PlainActionFuture<AnomalyResultResponse> listener2 = new PlainActionFuture<>();
         action.doExecute(null, request, listener2);
         Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L));
-        assertThat(
-            "actual message: " + e.getMessage(),
-            e.getMessage(),
-            containsString(AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG)
-        );
+        assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(ResultProcessor.TROUBLE_QUERYING_ERR_MSG));
         assertTrue(!((EndRunException) e).isEndNow());
     }

@@ -661,7 +658,7 @@ public void sendRequest(
             TransportRequestOptions options,
             TransportResponseHandler<T> handler
         ) {
-            if (action.equals(EntityResultAction.NAME)) {
+            if (action.equals(EntityADResultAction.NAME)) {
                 sender
                     .sendRequest(
                         connection,
@@ -715,7 +712,7 @@ public void testNonEmptyFeatures() throws InterruptedException, IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         setUpEntityResult(1);
@@ -766,14 +763,14 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler, spyStateManager);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         CircuitBreakerService openBreaker = mock(CircuitBreakerService.class);
         when(openBreaker.isOpen()).thenReturn(true);

         // register entity result action
-        new EntityResultTransportAction(
+        new EntityADResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             // since we send requests to testNodes[1]
             testNodes[1].transportService,
@@ -787,7 +784,8 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException {
             coldEntityQueue,
             threadPool,
             entityColdStartQueue,
-            adStats
+            adStats,
+            mock(ADSaveResultStrategy.class)
         );

         CountDownLatch inProgress = new CountDownLatch(1);
@@ -816,7 +814,7 @@ public void testNotAck() throws InterruptedException, IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::unackEntityResultHandler);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         setUpEntityResult(1);
@@ -847,13 +845,13 @@ public void testMultipleNode() throws InterruptedException, IOException {
         Entity entity3 = Entity.createEntityByReordering(attrs3);

         // we use ordered attributes values as the key to hashring
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity1.toString())))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity1.toString())))
            .thenReturn(Optional.of(testNodes[2].discoveryNode()));

-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity2.toString())))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity2.toString())))
            .thenReturn(Optional.of(testNodes[3].discoveryNode()));

-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity3.toString())))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity3.toString())))
            .thenReturn(Optional.of(testNodes[4].discoveryNode()));

         for (int i = 2; i <= 4; i++) {
@@ -883,7 +881,7 @@ public void testCacheSelectionError() throws IOException, InterruptedException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
         setUpEntityResult(1);
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         List<Entity> hotEntities = new ArrayList<>();
@@ -916,21 +914,21 @@ public void testCacheSelectionError() throws IOException, InterruptedException {
         assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS));

         // size 0 because cacheMissEntities has no record of these entities
-        verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher<List<EntityFeatureRequest>>() {
+        verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher<List<FeatureRequest>>() {
             @Override
-            public boolean matches(List<EntityFeatureRequest> argument) {
-                List<EntityFeatureRequest> arg = (argument);
+            public boolean matches(List<FeatureRequest> argument) {
+                List<FeatureRequest> arg = (argument);
                 LOG.info("size: " + arg.size());
                 return arg.size() == 0;
             }
         }));

-        verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher<List<EntityFeatureRequest>>() {
+        verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher<List<FeatureRequest>>() {
             @Override
-            public boolean matches(List<EntityFeatureRequest> argument) {
-                List<EntityFeatureRequest> arg = (argument);
+            public boolean matches(List<FeatureRequest> argument) {
+                List<FeatureRequest> arg = (argument);
                 LOG.info("size: " + arg.size());
                 return arg.size() == 0;
             }
@@ -940,7 +938,7 @@ public boolean matches(List<FeatureRequest> argument) {
     public void testCacheSelection() throws IOException, InterruptedException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         List<Entity> hotEntities = new ArrayList<>();
@@ -951,13 +949,13 @@ public void testCacheSelection() throws IOException, InterruptedException {
         Entity entity2 = Entity.createEntityByReordering(attrs2);
         coldEntities.add(entity2);

-        provider = mock(CacheProvider.class);
-        entityCache = mock(EntityCache.class);
+        provider = mock(ADCacheProvider.class);
+        entityCache = mock(ADPriorityCache.class);
         when(provider.get()).thenReturn(entityCache);
         when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(hotEntities, coldEntities));
         when(entityCache.get(any(), any())).thenReturn(null);

-        new EntityResultTransportAction(
+        new EntityADResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             // since we send requests to testNodes[1]
             testNodes[1].transportService,
@@ -971,7 +969,8 @@ public void testCacheSelection() throws IOException, InterruptedException {
             coldEntityQueue,
             threadPool,
             entityColdStartQueue,
-            adStats
+            adStats,
+            mock(ADSaveResultStrategy.class)
         );

         CountDownLatch modelNodeInProgress = new CountDownLatch(1);
@@ -987,21 +986,21 @@ public void testCacheSelection() throws IOException, InterruptedException {
         action.doExecute(null, request, listener);
         assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS));

-        verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher<List<EntityFeatureRequest>>() {
+        verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher<List<FeatureRequest>>() {
             @Override
-            public boolean matches(List<EntityFeatureRequest> argument) {
-                List<EntityFeatureRequest> arg = (argument);
+            public boolean matches(List<FeatureRequest> argument) {
+                List<FeatureRequest> arg = (argument);
                 LOG.info("size: " + arg.size() + " ; element: " + arg.get(0));
                 return arg.size() == 1 && arg.get(0).getEntity().equals(entity1);
             }
         }));

-        verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher<List<EntityFeatureRequest>>() {
+        verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher<List<FeatureRequest>>() {
             @Override
-            public boolean matches(List<EntityFeatureRequest> argument) {
-                List<EntityFeatureRequest> arg = (argument);
+            public boolean matches(List<FeatureRequest> argument) {
+                List<FeatureRequest> arg = (argument);
                 LOG.info("size: " + arg.size() + " ; element: " + arg.get(0));
                 return arg.size() == 1 && arg.get(0).getEntity().equals(entity2);
             }
@@ -1131,7 +1130,7 @@ public void testRetry() throws IOException, InterruptedException {
         }).when(coldEntityQueue).putAll(any());

         setUpTransportInterceptor(this::entityResultHandler);
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         setUpEntityResult(1);
@@ -1166,7 +1165,8 @@ public void testPageToString() {
             10000,
             1000,
             indexNameResolver,
-            clusterService
+            clusterService,
+            AnalysisType.AD
         );
         Map<Entity, double[]> results = new HashMap<>();
         Entity entity1 = Entity.createEntityByReordering(attrs1);
@@ -1193,7 +1193,8 @@ public void testEmptyPageToString() {
             10000,
             1000,
             indexNameResolver,
-            clusterService
+            clusterService,
+            AnalysisType.AD
         );

         CompositeRetriever.Page page = retriever.new Page(null);
@@ -1207,7 +1208,7 @@ private NodeStateManager setUpTestExceptionTestingInModelNode() throws IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));

         NodeStateManager modelNodeStateManager = mock(NodeStateManager.class);
diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java
index 38cdce966..f80d3630d 100644
--- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java
@@ -45,10 +45,8 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.WriteRequest;
 import org.opensearch.ad.AnomalyDetectorRunner;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.feature.Features;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.ml.ModelManager;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
@@ -74,6 +72,8 @@
 import org.opensearch.timeseries.breaker.CircuitBreakerService;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.feature.FeatureManager;
+import org.opensearch.timeseries.feature.Features;
 import org.opensearch.timeseries.util.RestHandlerUtils;
 import org.opensearch.transport.TransportService;

@@ -85,7 +85,7 @@ public class PreviewAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase {
     private AnomalyDetectorRunner runner;
     private ClusterService clusterService;
     private FeatureManager featureManager;
-    private ModelManager modelManager;
+    private ADModelManager modelManager;
     private Task task;
     private CircuitBreakerService circuitBreaker;

@@ -127,7 +127,7 @@ public void setUp() throws Exception {
         when(clusterService.state()).thenReturn(clusterState);

         featureManager = mock(FeatureManager.class);
-        modelManager = mock(ModelManager.class);
+        modelManager = mock(ADModelManager.class);
         runner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS);
         circuitBreaker = mock(CircuitBreakerService.class);
         when(circuitBreaker.isOpen()).thenReturn(false);
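Note: the hash-ring lookup drops its AD-specific suffix: getOwningNodeWithSameLocalAdVersionForRealtimeAD becomes getOwningNodeWithSameLocalVersionForRealtime on the shared org.opensearch.timeseries.cluster.HashRing. The stubbing pattern repeated throughout MultiEntityResultTests, as a self-contained sketch:

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Optional;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.timeseries.cluster.HashRing;

public class HashRingStubSketch {
    public static void main(String[] args) {
        HashRing hashRing = mock(HashRing.class);
        DiscoveryNode modelNode = mock(DiscoveryNode.class);
        // Renamed lookup: entity key in, owning model node out
        // (Optional.empty() when the ring has no owner for the key).
        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(modelNode));
    }
}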
diff --git a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java
index 013f00097..6bd8aa326 100644
--- a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java
@@ -16,10 +16,12 @@
 import java.util.HashSet;
 import java.util.concurrent.ExecutionException;

-import org.opensearch.ad.model.DetectorProfileName;
 import org.opensearch.plugins.Plugin;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.model.ProfileName;
+import org.opensearch.timeseries.transport.ProfileRequest;
+import org.opensearch.timeseries.transport.ProfileResponse;

 public class ProfileITTests extends OpenSearchIntegTestCase {

@@ -33,9 +35,9 @@ protected Collection<Class<? extends Plugin>> transportClientPlugins() {
     }

     public void testNormalProfile() throws ExecutionException, InterruptedException {
-        ProfileRequest profileRequest = new ProfileRequest("123", new HashSet<DetectorProfileName>(), false);
+        ProfileRequest profileRequest = new ProfileRequest("123", new HashSet<ProfileName>(), false);

-        ProfileResponse response = client().execute(ProfileAction.INSTANCE, profileRequest).get();
+        ProfileResponse response = client().execute(ADProfileAction.INSTANCE, profileRequest).get();
         assertTrue("getting profile failed", !response.hasFailures());
     }
 }
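Note: DetectorProfileName collapses into the shared ProfileName enum and the profile transport classes move to org.opensearch.timeseries.transport. A sketch of building a request using only names visible in these hunks; per test usage above, the boolean appears to flag a high-cardinality (multi-entity) detector, though that reading is inferred:

import java.util.HashSet;
import java.util.Set;

import org.opensearch.timeseries.model.ProfileName;
import org.opensearch.timeseries.transport.ProfileRequest;

public class ProfileRequestSketch {
    public static void main(String[] args) {
        Set<ProfileName> profiles = new HashSet<>();
        profiles.add(ProfileName.COORDINATING_NODE);
        profiles.add(ProfileName.INIT_PROGRESS);
        // ("123", profiles, forMultiEntityDetector) -- flag meaning inferred from the tests.
        ProfileRequest request = new ProfileRequest("123", profiles, false);
        assert request.getProfilesToBeRetrieved().contains(ProfileName.INIT_PROGRESS);
    }
}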
diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTests.java
index 7df0d5e02..46728f79e 100644
--- a/src/test/java/org/opensearch/ad/transport/ProfileTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ProfileTests.java
@@ -30,9 +30,6 @@
 import org.opensearch.Version;
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
-import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.model.DetectorProfileName;
-import org.opensearch.ad.model.ModelProfileOnNode;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -42,6 +39,12 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.model.ModelProfileOnNode;
+import org.opensearch.timeseries.model.ProfileName;
+import org.opensearch.timeseries.transport.ProfileNodeRequest;
+import org.opensearch.timeseries.transport.ProfileNodeResponse;
+import org.opensearch.timeseries.transport.ProfileRequest;
+import org.opensearch.timeseries.transport.ProfileResponse;

 import com.google.gson.JsonArray;
 import com.google.gson.JsonElement;
@@ -112,11 +115,11 @@ public void setUp() throws Exception {

     @Test
     public void testProfileNodeRequest() throws IOException {
-        Set<DetectorProfileName> profilesToRetrieve = new HashSet<DetectorProfileName>();
-        profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE);
+        Set<ProfileName> profilesToRetrieve = new HashSet<ProfileName>();
+        profilesToRetrieve.add(ProfileName.COORDINATING_NODE);
         ProfileRequest ProfileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false);
         ProfileNodeRequest ProfileNodeRequest = new ProfileNodeRequest(ProfileRequest);
-        assertEquals("ProfileNodeRequest has the wrong detector id", ProfileNodeRequest.getId(), detectorId);
+        assertEquals("ProfileNodeRequest has the wrong detector id", ProfileNodeRequest.getConfigId(), detectorId);
         assertEquals("ProfileNodeRequest has the wrong ProfileRequest", ProfileNodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve);

         // Test serialization
@@ -124,7 +127,7 @@ public void testProfileNodeRequest() throws IOException {
         ProfileNodeRequest.writeTo(output);
         StreamInput streamInput = output.bytes().streamInput();
         ProfileNodeRequest nodeRequest = new ProfileNodeRequest(streamInput);
-        assertEquals("serialization has the wrong detector id", nodeRequest.getId(), detectorId);
+        assertEquals("serialization has the wrong detector id", nodeRequest.getConfigId(), detectorId);
         assertEquals("serialization has the wrong ProfileRequest", nodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve);
     }

@@ -162,14 +165,14 @@ public void testProfileNodeResponse() throws IOException, JsonPathNotFoundException {
             );
         }

-        assertEquals("toXContent has the wrong shingle size", JsonDeserializer.getIntValue(json, ADCommonName.SHINGLE_SIZE), shingleSize);
+        assertEquals("toXContent has the wrong shingle size", JsonDeserializer.getIntValue(json, CommonName.SHINGLE_SIZE), shingleSize);
     }

     @Test
     public void testProfileRequest() throws IOException {
         String detectorId = "123";
-        Set<DetectorProfileName> profilesToRetrieve = new HashSet<DetectorProfileName>();
-        profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE);
+        Set<ProfileName> profilesToRetrieve = new HashSet<ProfileName>();
+        profilesToRetrieve.add(ProfileName.COORDINATING_NODE);
         ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false);

         // Test Serialization
@@ -182,7 +185,7 @@ public void testProfileRequest() throws IOException {
             readRequest.getProfilesToBeRetrieved(),
             profileRequest.getProfilesToBeRetrieved()
         );
-        assertEquals("Serialization has the wrong detector id", readRequest.getId(), profileRequest.getId());
+        assertEquals("Serialization has the wrong detector id", readRequest.getConfigId(), profileRequest.getConfigId());
     }

     @Test
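Note: testProfileNodeResponse renders the node response to JSON and asserts on individual fields via the repo's JsonDeserializer test helper. The same check in plain gson, with the field name assumed to be the literal behind CommonName.SHINGLE_SIZE:

import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

public class JsonFieldAssertSketch {
    public static void main(String[] args) {
        // "shingle_size" is an assumed constant value, not confirmed by this hunk.
        JsonObject json = JsonParser.parseString("{\"shingle_size\":6}").getAsJsonObject();
        assert json.get("shingle_size").getAsInt() == 6;
    }
}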
diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
index bccd385bb..36d2b29f8 100644
--- a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
@@ -29,35 +29,39 @@
 import org.junit.Test;
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.ml.ModelManager;
-import org.opensearch.ad.model.DetectorProfileName;
-import org.opensearch.ad.model.ModelProfile;
+import org.opensearch.ad.caching.ADCacheProvider;
+import org.opensearch.ad.caching.ADPriorityCache;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.plugins.Plugin;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.feature.FeatureManager;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.model.ModelProfile;
+import org.opensearch.timeseries.model.ProfileName;
+import org.opensearch.timeseries.transport.ProfileNodeRequest;
+import org.opensearch.timeseries.transport.ProfileNodeResponse;
+import org.opensearch.timeseries.transport.ProfileRequest;
+import org.opensearch.timeseries.transport.ProfileResponse;
 import org.opensearch.transport.TransportService;

 public class ProfileTransportActionTests extends OpenSearchIntegTestCase {
-    private ProfileTransportAction action;
+    private DelegateADProfileTransportAction action;
     private String detectorId = "Pl536HEBnXkDrah03glg";
     String node1, nodeName1;
     DiscoveryNode discoveryNode1;
-    Set<DetectorProfileName> profilesToRetrieve = new HashSet<DetectorProfileName>();
+    Set<ProfileName> profilesToRetrieve = new HashSet<ProfileName>();
     private int shingleSize = 6;
     private long modelSize = 4456448L;
     private String modelId = "Pl536HEBnXkDrah03glg_model_rcf_1";
-    private CacheProvider cacheProvider;
+    private ADCacheProvider cacheProvider;
     private int activeEntities = 10;
     private long totalUpdates = 127;
     private long multiEntityModelSize = 712480L;
-    private ModelManager modelManager;
+    private ADModelManager modelManager;
     private FeatureManager featureManager;

     @Override
@@ -65,13 +69,13 @@ public class ProfileTransportActionTests extends OpenSearchIntegTestCase {
     public void setUp() throws Exception {
         super.setUp();

-        modelManager = mock(ModelManager.class);
+        modelManager = mock(ADModelManager.class);
         featureManager = mock(FeatureManager.class);

         when(featureManager.getShingleSize(any(String.class))).thenReturn(shingleSize);

-        EntityCache cache = mock(EntityCache.class);
-        cacheProvider = mock(CacheProvider.class);
+        ADPriorityCache cache = mock(ADPriorityCache.class);
+        cacheProvider = mock(ADCacheProvider.class);
         when(cacheProvider.get()).thenReturn(cache);
         when(cache.getActiveEntities(anyString())).thenReturn(activeEntities);
         when(cache.getTotalUpdates(anyString())).thenReturn(totalUpdates);
@@ -98,7 +102,7 @@ public void setUp() throws Exception {

         Settings settings = Settings.builder().put("plugins.anomaly_detection.max_model_size_per_node", 100).build();

-        action = new ProfileTransportAction(
+        action = new DelegateADProfileTransportAction(
             client().threadPool(),
             clusterService(),
             mock(TransportService.class),
@@ -109,8 +113,8 @@ public void setUp() throws Exception {
             settings
         );

-        profilesToRetrieve = new HashSet<DetectorProfileName>();
-        profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE);
+        profilesToRetrieve = new HashSet<ProfileName>();
+        profilesToRetrieve.add(ProfileName.COORDINATING_NODE);
     }

     private void setUpModelSize(int maxModel) {
@@ -145,7 +149,7 @@ public void testNewNodeRequest() {

         ProfileNodeRequest profileNodeRequest1 = new ProfileNodeRequest(profileRequest);
         ProfileNodeRequest profileNodeRequest2 = action.newNodeRequest(profileRequest);

-        assertEquals(profileNodeRequest1.getId(), profileNodeRequest2.getId());
+        assertEquals(profileNodeRequest1.getConfigId(), profileNodeRequest2.getConfigId());
         assertEquals(profileNodeRequest2.getProfilesToBeRetrieved(), profileNodeRequest2.getProfilesToBeRetrieved());
     }

@@ -160,8 +164,8 @@ public void testNodeOperation() {
         assertEquals(shingleSize, response.getShingleSize());
         assertEquals(null, response.getModelSize());

-        profilesToRetrieve = new HashSet<DetectorProfileName>();
-        profilesToRetrieve.add(DetectorProfileName.TOTAL_SIZE_IN_BYTES);
+        profilesToRetrieve = new HashSet<ProfileName>();
+        profilesToRetrieve.add(ProfileName.TOTAL_SIZE_IN_BYTES);

         profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false, nodeId);
         response = action.nodeOperation(new ProfileNodeRequest(profileRequest));
@@ -175,8 +179,8 @@ public void testNodeOperation() {
     public void testMultiEntityNodeOperation() {
         setUpModelSize(100);
         DiscoveryNode nodeId = clusterService().localNode();
-        profilesToRetrieve = new HashSet<DetectorProfileName>();
-        profilesToRetrieve.add(DetectorProfileName.ACTIVE_ENTITIES);
+        profilesToRetrieve = new HashSet<ProfileName>();
+        profilesToRetrieve.add(ProfileName.ACTIVE_ENTITIES);

         ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId);
         ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest));
@@ -184,7 +188,7 @@ public void testMultiEntityNodeOperation() {
         assertEquals(activeEntities, response.getActiveEntities());
         assertEquals(null, response.getModelSize());

-        profilesToRetrieve.add(DetectorProfileName.INIT_PROGRESS);
+        profilesToRetrieve.add(ProfileName.INIT_PROGRESS);

         profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId);
         response = action.nodeOperation(new ProfileNodeRequest(profileRequest));
@@ -193,7 +197,7 @@ public void testMultiEntityNodeOperation() {
         assertEquals(null, response.getModelSize());
         assertEquals(totalUpdates, response.getTotalUpdates());

-        profilesToRetrieve.add(DetectorProfileName.MODELS);
+        profilesToRetrieve.add(ProfileName.MODELS);

         profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId);
         response = action.nodeOperation(new ProfileNodeRequest(profileRequest));
@@ -210,7 +214,7 @@ public void testModelCount() {

         Settings settings = Settings.builder().put("plugins.anomaly_detection.max_model_size_per_node", 1).build();

-        action = new ProfileTransportAction(
+        action = new DelegateADProfileTransportAction(
             client().threadPool(),
             clusterService(),
             mock(TransportService.class),
@@ -222,8 +226,8 @@ public void testModelCount() {
         );

         DiscoveryNode nodeId = clusterService().localNode();
-        profilesToRetrieve = new HashSet<DetectorProfileName>();
-        profilesToRetrieve.add(DetectorProfileName.MODELS);
+        profilesToRetrieve = new HashSet<ProfileName>();
+        profilesToRetrieve.add(ProfileName.MODELS);
         ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId);
         ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest));
         assertEquals(2, response.getModelCount());
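Note: testModelCount rebuilds the action with plugins.anomaly_detection.max_model_size_per_node dialed down to 1 so the MODELS profile truncates. A self-contained sketch of how such a node setting is built and read back:

import org.opensearch.common.settings.Settings;

public class NodeSettingSketch {
    public static void main(String[] args) {
        Settings settings = Settings.builder()
            .put("plugins.anomaly_detection.max_model_size_per_node", 1)
            .build();
        // The default (second argument) applies only when the key is absent.
        assert settings.getAsInt("plugins.anomaly_detection.max_model_size_per_node", 100) == 1;
    }
}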
RCFPollingTransportAction action; @@ -104,7 +105,7 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); hashRing = mock(HashRing.class); transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); - manager = mock(ModelManager.class); + manager = mock(ADModelManager.class); transportService = new TransportService( Settings.EMPTY, mock(Transport.class), @@ -112,7 +113,8 @@ public void setUp() throws Exception { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); future = new PlainActionFuture<>(); @@ -189,7 +191,7 @@ public void testDoubleNaN() { public void testNormal() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); @@ -208,15 +210,15 @@ public void testNormal() { } public void testNoNodeFoundForModel() { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(Optional.empty()); action = new RCFPollingTransportAction( - mock(ActionFilters.class), - transportService, - Settings.EMPTY, - manager, - hashRing, - clusterService - ); + mock(ActionFilters.class), + transportService, + Settings.EMPTY, + manager, + hashRing, + clusterService + ); action.doExecute(mock(Task.class), request, future); assertException(future, TimeSeriesException.class, RCFPollingTransportAction.NO_NODE_FOUND_MSG); } @@ -305,7 +307,7 @@ public void testGetRemoteNormalResponse() { clusterService ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -333,7 +335,7 @@ public void testGetRemoteFailureResponse() { clusterService ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java index 23e5db59c..a3ed6ee7c 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java @@ -37,15 +37,12 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.stats.ADStat; import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; 
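A recurring mechanical change across these test fixtures is the TransportService constructor, which grew a trailing Tracer parameter in OpenSearch core; every setUp() now appends NoopTracer.INSTANCE. A sketch of the updated wiring, assuming a mocked thread pool for the argument that the hunks above elide:

```java
TransportService transportService = new TransportService(
    Settings.EMPTY,
    mock(Transport.class),
    mock(ThreadPool.class),                       // assumed; not visible in the hunks
    TransportService.NOOP_TRANSPORT_INTERCEPTOR,
    x -> null,                                    // local-node factory, as in the tests
    null,                                         // no ClusterSettings needed here
    Collections.emptySet(),                       // no task headers
    NoopTracer.INSTANCE                           // new trailing tracer argument
);
```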
import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -54,10 +51,14 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.tasks.Task; +import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; @@ -87,10 +88,10 @@ public void setUp() throws Exception { hashRing = mock(HashRing.class); node = mock(DiscoveryNode.class); doReturn(Optional.of(node)).when(hashRing).getNodeByAddress(any()); - Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() { + Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() { { - put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; @@ -106,10 +107,11 @@ public void testNormal() { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -164,10 +166,11 @@ public void testExecutionException() { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -280,10 +283,11 @@ public void testCircuitBreaker() { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService breakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -331,10 +335,11 @@ public void testCorruptModel() { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class); RCFResultTransportAction action = new RCFResultTransportAction( mock(ActionFilters.class), @@ -358,7 +363,7 @@ public void testCorruptModel() { action.doExecute(mock(Task.class), request, future);
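The stats plumbing in this file moves from the AD-only ADStat to the shared TimeSeriesStat, with CounterSupplier relocated under org.opensearch.timeseries.stats.suppliers, and the corruption stat renamed to AD_MODEL_CORRUTPION_COUNT (the misspelling comes from the original constant). A compact sketch of the pattern, construction through read-back; the ADStats wrapper shown here is an assumption, since its constructor is not part of this diff, and the boolean flag simply mirrors the calls above:

```java
// Build the stats registry the way setUp() does: counter-backed stats keyed by name
Map<String, TimeSeriesStat<?>> statsMap = new HashMap<>();
statsMap.put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
ADStats adStats = new ADStats(statsMap); // assumed wrapper; construction not shown in the diff

// Read a counter back, as the assertions following this call do
long corruptionCount = (Long) adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue();
```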
expectThrows(IllegalArgumentException.class, () -> future.actionGet()); - Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue(); assertEquals(1L, ((Long) val).longValue()); verify(manager, times(1)).clear(eq(detectorId), any()); } diff --git a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java index bc87faf13..65b0ee95d 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java @@ -24,13 +24,13 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.HistoricalAnalysisIntegTestCase; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ADTask; import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.model.TimeSeriesTask; @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) public class SearchADTasksTransportActionTests extends HistoricalAnalysisIntegTestCase { @@ -81,7 +81,7 @@ public void testSearchWithExistingTask() throws IOException { private SearchRequest searchRequest(boolean isLatest) { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(ADTask.IS_LATEST_FIELD, isLatest)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, isLatest)); sourceBuilder.query(query); SearchRequest request = new SearchRequest().source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); return request; diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java index 877fcd887..df7844ef1 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java @@ -58,6 +58,7 @@ import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregator; import org.opensearch.tasks.Task; +import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; import org.opensearch.transport.Transport; @@ -101,7 +102,8 @@ public void setUp() throws Exception { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); client = mock(Client.class); diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java index 796d492e1..6a0dd234b 100644 --- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java @@ -21,7 +21,6 @@ import org.junit.Before; import org.junit.Test; import org.opensearch.action.FailedNodeException; -import 
org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; @@ -30,6 +29,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.StatsNodeResponse; +import org.opensearch.timeseries.transport.StatsNodesResponse; +import org.opensearch.timeseries.transport.StatsResponse; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; public class StatsAnomalyDetectorActionTests extends OpenSearchTestCase { @@ -47,20 +50,20 @@ public void testStatsAction() { @Test public void testStatsResponse() throws IOException { - ADStatsResponse adStatsResponse = new ADStatsResponse(); + StatsResponse adStatsResponse = new StatsResponse(); Map<String, Object> testClusterStats = new HashMap<>(); testClusterStats.put("test_response", 1); adStatsResponse.setClusterStats(testClusterStats); - List<ADStatsNodeResponse> responses = Collections.emptyList(); + List<StatsNodeResponse> responses = Collections.emptyList(); List<FailedNodeException> failures = Collections.emptyList(); - ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); - adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); + StatsNodesResponse adStatsNodesResponse = new StatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setStatsNodesResponse(adStatsNodesResponse); - StatsAnomalyDetectorResponse response = new StatsAnomalyDetectorResponse(adStatsResponse); + StatsTimeSeriesResponse response = new StatsTimeSeriesResponse(adStatsResponse); BytesStreamOutput out = new BytesStreamOutput(); response.writeTo(out); StreamInput input = out.bytes().streamInput(); - StatsAnomalyDetectorResponse newResponse = new StatsAnomalyDetectorResponse(input); + StatsTimeSeriesResponse newResponse = new StatsTimeSeriesResponse(input); assertNotNull(newResponse); XContentBuilder builder = XContentFactory.jsonBuilder(); diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java index 7c877c086..ceff494ed 100644 --- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java @@ -16,9 +16,11 @@ import org.junit.Before; import org.opensearch.ad.ADIntegTestCase; -import org.opensearch.ad.stats.InternalStatNames; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.stats.InternalStatNames; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.transport.StatsRequest; +import org.opensearch.timeseries.transport.StatsTimeSeriesResponse; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -47,14 +49,14 @@ public void setUp() throws Exception { } public void testStatsAnomalyDetectorWithNodeLevelStats() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); adStatsRequest.addStat(InternalStatNames.JVM_HEAP_USAGE.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1,
response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); assertTrue( response .getAdStatsResponse() - .getADStatsNodesResponse() + .getStatsNodesResponse() .getNodes() .get(0) .getStatsMap() @@ -63,39 +65,39 @@ public void testStatsAnomalyDetectorWithNodeLevelStats() { } public void testStatsAnomalyDetectorWithClusterLevelStats() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName()); - adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); - Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + adStatsRequest.addStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getStatsNodesResponse().getNodes().get(0).getStatsMap(); Map clusterStats = response.getAdStatsResponse().getClusterStats(); assertEquals(0, statsMap.size()); assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName())); - assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertEquals(1L, clusterStats.get(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())); } public void testStatsAnomalyDetectorWithDetectorCount() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName()); - StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); - Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getStatsNodesResponse().getNodes().get(0).getStatsMap(); Map clusterStats = response.getAdStatsResponse().getClusterStats(); assertEquals(0, statsMap.size()); assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName())); - assertFalse(clusterStats.containsKey(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertFalse(clusterStats.containsKey(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())); } public void testStatsAnomalyDetectorWithSingleEntityDetectorCount() { - ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); - adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()); - StatsAnomalyDetectorResponse response = 
client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); - assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); - Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + StatsRequest adStatsRequest = new StatsRequest(clusterService().localNode()); + adStatsRequest.addStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()); + StatsTimeSeriesResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getStatsNodesResponse().getNodes().get(0).getStatsMap(); Map clusterStats = response.getAdStatsResponse().getClusterStats(); assertEquals(0, statsMap.size()); - assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertEquals(1L, clusterStats.get(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())); assertFalse(clusterStats.containsKey(StatNames.DETECTOR_COUNT.getName())); } diff --git a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java index a2e98bf88..786d34cea 100644 --- a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java @@ -26,6 +26,8 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; public class StopDetectorActionTests extends OpenSearchIntegTestCase { @@ -43,7 +45,7 @@ public void testStopDetectorAction() { @Test public void fromActionRequest_Success() { - StopDetectorRequest stopDetectorRequest = new StopDetectorRequest("adID"); + StopConfigRequest stopDetectorRequest = new StopConfigRequest("adID"); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -55,41 +57,41 @@ public void writeTo(StreamOutput out) throws IOException { stopDetectorRequest.writeTo(out); } }; - StopDetectorRequest result = StopDetectorRequest.fromActionRequest(actionRequest); + StopConfigRequest result = StopConfigRequest.fromActionRequest(actionRequest); assertNotSame(result, stopDetectorRequest); - assertEquals(result.getAdID(), stopDetectorRequest.getAdID()); + assertEquals(result.getConfigID(), stopDetectorRequest.getConfigID()); } @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - StopDetectorResponse response = new StopDetectorResponse(true); + StopConfigResponse response = new StopConfigResponse(true); response.writeTo(bytesStreamOutput); - StopDetectorResponse parsedResponse = new StopDetectorResponse(bytesStreamOutput.bytes().streamInput()); + StopConfigResponse parsedResponse = new StopConfigResponse(bytesStreamOutput.bytes().streamInput()); assertNotEquals(response, parsedResponse); assertEquals(response.success(), parsedResponse.success()); } @Test public void fromActionResponse_Success() throws IOException { - StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); + StopConfigResponse stopDetectorResponse = new StopConfigResponse(true); ActionResponse actionResponse = new ActionResponse() { 
@Override public void writeTo(StreamOutput streamOutput) throws IOException { stopDetectorResponse.writeTo(streamOutput); } }; - StopDetectorResponse result = stopDetectorResponse.fromActionResponse(actionResponse); + StopConfigResponse result = stopDetectorResponse.fromActionResponse(actionResponse); assertNotSame(result, stopDetectorResponse); assertEquals(result.success(), stopDetectorResponse.success()); - StopDetectorResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse); + StopConfigResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse); assertEquals(parsedStopDetectorResponse, stopDetectorResponse); } @Test public void toXContentTest() throws IOException { - StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); + StopConfigResponse stopDetectorResponse = new StopConfigResponse(true); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); stopDetectorResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java index 410274e37..20c1b06ef 100644 --- a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java @@ -29,7 +29,7 @@ import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -38,6 +38,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.tasks.Task; +import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; @@ -55,10 +56,11 @@ public void testNormal() { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -84,10 +86,11 @@ public void testExecutionException() { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); - ModelManager manager = mock(ModelManager.class); + ADModelManager manager = mock(ADModelManager.class); ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); doThrow(NullPointerException.class) .when(manager) diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java index 4a1fae9cb..5f7a4ede4 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java +++ 
b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java @@ -21,7 +21,9 @@ import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.transport.ValidateConfigRequest; import com.google.common.collect.ImmutableMap; @@ -38,14 +40,14 @@ public void testValidateAnomalyDetectorRequestSerialization() throws IOException TimeValue requestTimeout = new TimeValue(1000L); String typeStr = "type"; - ValidateAnomalyDetectorRequest request1 = new ValidateAnomalyDetectorRequest(detector, typeStr, 1, 1, 1, requestTimeout); + ValidateConfigRequest request1 = new ValidateConfigRequest(AnalysisType.AD, detector, typeStr, 1, 1, 1, requestTimeout, 10); // Test serialization BytesStreamOutput output = new BytesStreamOutput(); request1.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - ValidateAnomalyDetectorRequest request2 = new ValidateAnomalyDetectorRequest(input); - assertEquals("serialization has the wrong detector", request2.getDetector(), detector); + ValidateConfigRequest request2 = new ValidateConfigRequest(input); + assertEquals("serialization has the wrong detector", request2.getConfig(), detector); assertEquals("serialization has the wrong typeStr", request2.getValidationType(), typeStr); assertEquals("serialization has the wrong requestTimeout", request2.getRequestTimeout(), requestTimeout); } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java index 6c52634d0..0b28b67f7 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java @@ -16,12 +16,13 @@ import java.util.Map; import org.junit.Test; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.ConfigValidationIssue; +import org.opensearch.timeseries.transport.ValidateConfigResponse; public class ValidateAnomalyDetectorResponseTests extends AbstractTimeSeriesTest { @@ -30,22 +31,22 @@ public void testResponseSerialization() throws IOException { Map subIssues = new HashMap<>(); subIssues.put("a", "b"); subIssues.put("c", "d"); - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); + ValidateConfigResponse response = new ValidateConfigResponse(issue); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + 
ValidateConfigResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); assertEquals("serialization has the wrong issue", issue, readResponse.getIssue()); } @Test public void testResponseSerializationWithEmptyIssue() throws IOException { - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null); + ValidateConfigResponse response = new ValidateConfigResponse((ConfigValidationIssue) null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + ValidateConfigResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); assertNull("serialization should have empty issue", readResponse.getIssue()); } @@ -53,8 +54,8 @@ public void testResponseToXContentWithSubIssues() throws IOException { Map subIssues = new HashMap<>(); subIssues.put("a", "b"); subIssues.put("c", "d"); - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); + ValidateConfigResponse response = new ValidateConfigResponse(issue); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); String message = issue.getMessage(); assertEquals( @@ -64,27 +65,27 @@ public void testResponseToXContentWithSubIssues() throws IOException { } public void testResponseToXContent() throws IOException { - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssue(); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssue(); + ValidateConfigResponse response = new ValidateConfigResponse(issue); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); String message = issue.getMessage(); assertEquals("{\"detector\":{\"name\":{\"message\":\"" + message + "\"}}}", validationResponse); } public void testResponseToXContentNull() throws IOException { - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null); + ValidateConfigResponse response = new ValidateConfigResponse((ConfigValidationIssue) null); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); assertEquals("{}", validationResponse); } public void testResponseToXContentWithIntervalRec() throws IOException { long intervalRec = 5; - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); + ValidateConfigResponse response = new ValidateConfigResponse(issue); String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); assertEquals( "{\"model\":{\"detection_interval\":{\"message\":\"" - + ADCommonMessages.DETECTOR_INTERVAL_REC + + CommonMessages.INTERVAL_REC + intervalRec + 
"\",\"suggested_value\":{\"period\":{\"interval\":5,\"unit\":\"Minutes\"}}}}}", validationResponse @@ -94,12 +95,12 @@ public void testResponseToXContentWithIntervalRec() throws IOException { @Test public void testResponseSerializationWithIntervalRec() throws IOException { long intervalRec = 5; - DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); - ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + ConfigValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); + ValidateConfigResponse response = new ValidateConfigResponse(issue); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + ValidateConfigResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); assertEquals(issue, readResponse.getIssue()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java index 604fc2c46..076ee60f3 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java @@ -25,12 +25,15 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.common.unit.TimeValue; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.ValidateConfigRequest; +import org.opensearch.timeseries.transport.ValidateConfigResponse; import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; @@ -44,30 +47,34 @@ public void testValidateAnomalyDetectorWithNoIssue() throws IOException { AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(sumValueFeature(nameField, ipField + ".is_error", "test-2"))); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNull(response.getIssue()); } @Test public void testValidateAnomalyDetectorWithNoIndexFound() throws IOException { AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request 
= new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.INDICES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -80,15 +87,17 @@ public void testValidateAnomalyDetectorWithDuplicateName() throws IOException { ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); createDetectorIndex(); createDetector(anomalyDetector); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -99,15 +108,17 @@ public void testValidateAnomalyDetectorWithNonExistingFeatureField() throws IOEx Feature maxFeature = maxValueFeature(nameField, "non_existing_field", nameField); AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -123,15 +134,17 @@ public void testValidateAnomalyDetectorWithDuplicateFeatureAggregationNames() th AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = 
client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertTrue(response.getIssue().getMessage().contains("Config has duplicate feature aggregation query names:")); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); @@ -145,15 +158,17 @@ public void testValidateAnomalyDetectorWithDuplicateFeatureNamesAndDuplicateAggr AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertTrue(response.getIssue().getMessage().contains("Config has duplicate feature aggregation query names:")); assertTrue(response.getIssue().getMessage().contains("There are duplicate feature names:")); @@ -168,15 +183,17 @@ public void testValidateAnomalyDetectorWithDuplicateFeatureNames() throws IOExce AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertTrue( "actual: " + response.getIssue().getMessage(), @@ -191,15 +208,17 @@ public void testValidateAnomalyDetectorWithInvalidFeatureField() throws IOExcept Feature maxFeature = maxValueFeature(nameField, categoryField, nameField); AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, 
response.getIssue().getAspect()); @@ -218,15 +237,17 @@ public void testValidateAnomalyDetectorWithUnknownFeatureField() throws IOExcept ImmutableList.of(new Feature(randomAlphaOfLength(5), nameField, true, aggregationBuilder)) ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); @@ -241,15 +262,17 @@ public void testValidateAnomalyDetectorWithMultipleInvalidFeatureField() throws AnomalyDetector anomalyDetector = TestHelpers .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNotNull(response.getIssue()); assertEquals(response.getIssue().getSubIssues().keySet().size(), 2); assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); @@ -273,15 +296,17 @@ public void testValidateAnomalyDetectorWithCustomResultIndex() throws IOExceptio resultIndex ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNull(response.getIssue()); } @@ -311,15 +336,17 @@ public void testValidateAnomalyDetectorWithCustomResultIndexWithInvalidMapping() resultIndex ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, 
request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.RESULT_INDEX, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertTrue(response.getIssue().getMessage().contains(CommonMessages.INVALID_RESULT_INDEX_MAPPING)); @@ -340,15 +367,17 @@ private void testValidateAnomalyDetectorWithCustomResultIndex(boolean resultInde resultIndex ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertNull(response.getIssue()); } @@ -372,18 +401,21 @@ public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOExcept null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertEquals(CommonMessages.INVALID_NAME, response.getIssue().getMessage()); @@ -409,18 +441,21 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept null, TestHelpers.randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertTrue(response.getIssue().getMessage().contains("Name should be shortened. 
The maximum limit is")); @@ -430,15 +465,17 @@ public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOExcept public void testValidateAnomalyDetectorWithNonExistentTimefield() throws IOException { AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.TIMEFIELD_FIELD, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertEquals( @@ -451,15 +488,17 @@ public void testValidateAnomalyDetectorWithNonExistentTimefield() throws IOExcep public void testValidateAnomalyDetectorWithNonDateTimeField() throws IOException { AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(categoryField, "index-test"); ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); - ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + ValidateConfigRequest request = new ValidateConfigRequest( + AnalysisType.AD, anomalyDetector, ValidationAspect.DETECTOR.getName(), 5, 5, 5, - new TimeValue(5_000L) + new TimeValue(5_000L), + 10 ); - ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + ValidateConfigResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); assertEquals(ValidationIssueType.TIMEFIELD_FIELD, response.getIssue().getType()); assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); assertEquals( diff --git a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java index 12f966ffe..353611121 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java @@ -29,7 +29,6 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.transport.AnomalyResultTests; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -42,6 +41,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; public abstract class AbstractIndexHandlerTest extends AbstractTimeSeriesTest { enum IndexCreation { @@ -92,7 +92,7 @@ public void setUp() throws Exception { setWriteBlockAdResultIndex(false); context = TestHelpers.createThreadPool(); clientUtil = new ClientUtil(client); - indexUtil = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); + indexUtil = new 
IndexUtils(clusterService, indexNameResolver); } protected void setWriteBlockAdResultIndex(boolean blocked) { diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java index 68699b74e..af3442433 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java @@ -34,9 +34,10 @@ import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -46,17 +47,20 @@ import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.IndexUtils; import com.google.common.collect.ImmutableList; public class AnomalyResultBulkIndexHandlerTests extends ADUnitTestCase { - private AnomalyResultBulkIndexHandler bulkIndexHandler; + private ResultBulkIndexingHandler bulkIndexHandler; private Client client; private IndexUtils indexUtils; private ActionListener listener; private ADIndexManagement anomalyDetectionIndices; + private String configId; @Override public void setUp() throws Exception { @@ -70,14 +74,17 @@ public void setUp() throws Exception { indexUtils = mock(IndexUtils.class); ClusterService clusterService = mock(ClusterService.class); ThreadPool threadPool = mock(ThreadPool.class); - bulkIndexHandler = new AnomalyResultBulkIndexHandler( + bulkIndexHandler = new ResultBulkIndexingHandler( client, settings, threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, clientUtil, indexUtils, clusterService, - anomalyDetectionIndices + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF ); listener = spy(new ActionListener() { @Override @@ -86,10 +93,11 @@ public void onResponse(BulkResponse bulkItemResponses) {} @Override public void onFailure(Exception e) {} }); + configId = "testId"; } public void testNullAnomalyResults() { - bulkIndexHandler.bulkIndexAnomalyResult(null, null, listener); + bulkIndexHandler.bulk(null, null, null, listener); verify(listener, times(1)).onResponse(null); verify(anomalyDetectionIndices, never()).doesConfigIndexExist(); } @@ -97,9 +105,9 @@ public void testNullAnomalyResults() { public void testAnomalyResultBulkIndexHandler_IndexNotExist() { when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(false); AnomalyResult anomalyResult = mock(AnomalyResult.class); - when(anomalyResult.getConfigId()).thenReturn("testId"); + when(anomalyResult.getConfigId()).thenReturn(configId); - bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + bulkIndexHandler.bulk("testIndex", ImmutableList.of(anomalyResult), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); 
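The generic ResultBulkIndexingHandler replaces the AD-specific bulk handler here, and the entry point changes from bulkIndexAnomalyResult(resultIndex, results, listener) to bulk(resultIndex, results, configId, listener), threading the config id through explicitly. A sketch of a call site under those signatures, reusing the helpers this test already imports:

```java
String configId = "testId";       // the config (detector) id the results belong to
String resultIndex = null;        // null falls back to the default AD result index
List<AnomalyResult> results = ImmutableList.of(TestHelpers.randomAnomalyDetectResult());

bulkIndexHandler.bulk(resultIndex, results, configId, ActionListener.wrap(
    bulkResponse -> { /* results indexed; inspect bulkResponse for per-item failures */ },
    exception -> { /* e.g. missing custom index or invalid result-index mapping */ }
));
```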
assertEquals("Can't find result index testIndex", exceptionCaptor.getValue().getMessage()); } @@ -108,9 +116,10 @@ public void testAnomalyResultBulkIndexHandler_InValidResultIndexMapping() { when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(true); when(anomalyDetectionIndices.isValidResultIndexMapping("testIndex")).thenReturn(false); AnomalyResult anomalyResult = mock(AnomalyResult.class); - when(anomalyResult.getConfigId()).thenReturn("testId"); - bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + when(anomalyResult.getConfigId()).thenReturn(configId); + + bulkIndexHandler.bulk("testIndex", ImmutableList.of(anomalyResult), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("wrong index mapping of custom AD result index", exceptionCaptor.getValue().getMessage()); } @@ -119,10 +128,10 @@ public void testAnomalyResultBulkIndexHandler_FailBulkIndexAnomaly() throws IOEx when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(true); when(anomalyDetectionIndices.isValidResultIndexMapping("testIndex")).thenReturn(true); AnomalyResult anomalyResult = mock(AnomalyResult.class); - when(anomalyResult.getConfigId()).thenReturn("testId"); + when(anomalyResult.getConfigId()).thenReturn(configId); when(anomalyResult.toXContent(any(), any())).thenThrow(new RuntimeException()); - bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + bulkIndexHandler.bulk("testIndex", ImmutableList.of(anomalyResult), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Failed to prepare request to bulk index anomaly results", exceptionCaptor.getValue().getMessage()); } @@ -133,7 +142,7 @@ public void testCreateADResultIndexNotAcknowledged() throws IOException { listener.onResponse(new CreateIndexResponse(false, false, ANOMALY_RESULT_INDEX_ALIAS)); return null; }).when(anomalyDetectionIndices).initDefaultResultIndexDirectly(any()); - bulkIndexHandler.bulkIndexAnomalyResult(null, ImmutableList.of(mock(AnomalyResult.class)), listener); + bulkIndexHandler.bulk(null, ImmutableList.of(mock(AnomalyResult.class)), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals("Creating anomaly result index with mappings call not acknowledged", exceptionCaptor.getValue().getMessage()); } @@ -166,8 +175,7 @@ public void testWrongAnomalyResult() { listener.onResponse(bulkResponse); return null; }).when(client).bulk(any(), any()); - bulkIndexHandler - .bulkIndexAnomalyResult(null, ImmutableList.of(wrongAnomalyResult(), TestHelpers.randomAnomalyDetectResult()), listener); + bulkIndexHandler.bulk(null, ImmutableList.of(wrongAnomalyResult(), TestHelpers.randomAnomalyDetectResult()), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertTrue(exceptionCaptor.getValue().getMessage().contains("VersionConflictEngineException")); } @@ -184,7 +192,7 @@ public void testBulkSaveException() { return null; }).when(client).bulk(any(), any()); - bulkIndexHandler.bulkIndexAnomalyResult(null, ImmutableList.of(TestHelpers.randomAnomalyDetectResult()), listener); + bulkIndexHandler.bulk(null, ImmutableList.of(TestHelpers.randomAnomalyDetectResult()), configId, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals(testError, exceptionCaptor.getValue().getMessage()); } diff --git 
diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java
index b17008e1d..616fc0a51 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java
@@ -34,7 +34,10 @@
 import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
 import org.opensearch.ad.constant.ADCommonName;
+import org.opensearch.ad.indices.ADIndex;
+import org.opensearch.ad.indices.ADIndexManagement;
 import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.action.ActionListener;
@@ -42,6 +45,7 @@
 import org.opensearch.timeseries.NodeStateManager;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.transport.handler.ResultIndexingHandler;
 
 public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest {
     @Mock
@@ -54,7 +58,7 @@ public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest {
     @Before
     public void setUp() throws Exception {
         super.setUp();
-        super.setUpLog4jForJUnit(AnomalyIndexHandler.class);
+        super.setUpLog4jForJUnit(ResultIndexingHandler.class);
     }
 
     @Override
@@ -81,7 +85,7 @@ public void testSavingAdResult() throws IOException {
             listener.onResponse(mock(IndexResponse.class));
             return null;
         }).when(client).index(any(IndexRequest.class), ArgumentMatchers.<ActionListener<IndexResponse>>any());
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<AnomalyResult>(
+        ResultIndexingHandler<AnomalyResult, ADIndex, ADIndexManagement> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -89,19 +93,21 @@ public void testSavingAdResult() throws IOException {
             anomalyDetectionIndices,
             clientUtil,
             indexUtil,
-            clusterService
+            clusterService,
+            AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF
         );
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
-        assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true));
+        assertEquals(1, testAppender.countMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true));
     }
 
     @Test
     public void testSavingFailureNotRetry() throws InterruptedException, IOException {
         savingFailureTemplate(false, 1, true);
 
-        assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true));
-        assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true));
-        assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true));
+        assertEquals(1, testAppender.countMessage(ResultIndexingHandler.FAIL_TO_SAVE_ERR_MSG, true));
+        assertTrue(!testAppender.containsMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true));
+        assertTrue(!testAppender.containsMessage(ResultIndexingHandler.RETRY_SAVING_ERR_MSG, true));
     }
 
     @Test
@@ -109,15 +115,15 @@ public void testSavingFailureRetry() throws InterruptedException, IOException {
         setWriteBlockAdResultIndex(false);
         savingFailureTemplate(true, 3, true);
 
-        assertEquals(2, testAppender.countMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true));
-        assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true));
-        assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true));
+        assertEquals(2, testAppender.countMessage(ResultIndexingHandler.RETRY_SAVING_ERR_MSG, true));
+        assertEquals(1, testAppender.countMessage(ResultIndexingHandler.FAIL_TO_SAVE_ERR_MSG, true));
+        assertTrue(!testAppender.containsMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true));
     }
 
     @Test
     public void testIndexWriteBlock() {
         setWriteBlockAdResultIndex(true);
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<AnomalyResult>(
+        ResultIndexingHandler<AnomalyResult, ADIndex, ADIndexManagement> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -125,17 +131,19 @@ public void testIndexWriteBlock() {
             anomalyDetectionIndices,
             clientUtil,
             indexUtil,
-            clusterService
+            clusterService,
+            AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF
         );
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
 
-        assertTrue(testAppender.containsMessage(AnomalyIndexHandler.CANNOT_SAVE_ERR_MSG, true));
+        assertTrue(testAppender.containsMessage(ResultIndexingHandler.CANNOT_SAVE_ERR_MSG, true));
     }
 
     @Test
     public void testAdResultIndexExist() throws IOException {
         setUpSavingAnomalyResultIndex(false, IndexCreation.RESOURCE_EXISTS_EXCEPTION);
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<AnomalyResult>(
+        ResultIndexingHandler<AnomalyResult, ADIndex, ADIndexManagement> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -143,7 +151,9 @@ public void testAdResultIndexExist() throws IOException {
             anomalyDetectionIndices,
             clientUtil,
             indexUtil,
-            clusterService
+            clusterService,
+            AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF
         );
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
         verify(client, times(1)).index(any(), any());
@@ -155,7 +165,7 @@ public void testAdResultIndexOtherException() throws IOException {
         expectedEx.expectMessage("Error in saving .opendistro-anomaly-results for detector " + detectorId);
 
         setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION);
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<AnomalyResult>(
+        ResultIndexingHandler<AnomalyResult, ADIndex, ADIndexManagement> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -163,7 +173,9 @@ public void testAdResultIndexOtherException() throws IOException {
             anomalyDetectionIndices,
             clientUtil,
             indexUtil,
-            clusterService
+            clusterService,
+            AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF
         );
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
         verify(client, never()).index(any(), any());
@@ -213,7 +225,7 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep
             .put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(1))
             .build();
 
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<AnomalyResult>(
+        ResultIndexingHandler<AnomalyResult, ADIndex, ADIndexManagement> handler = new ResultIndexingHandler<>(
             client,
             backoffSettings,
             threadPool,
@@ -221,7 +233,9 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep
             anomalyDetectionIndices,
             clientUtil,
             indexUtil,
-            clusterService
+            clusterService,
+            AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF
        );
 
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
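The repeated constructor swap above is the whole migration in miniature: AnomalyIndexHandler, which was AD-specific, becomes a ResultIndexingHandler parameterized on the result type, the index enum, and the index-management type (hence the new ADIndex and ADIndexManagement imports), with the retry backoff delay and retry count injected as settings rather than hard-coded. A hedged sketch of the AD instantiation; the fourth argument is elided by the hunks above and is assumed here to be the usual alias constant:

    ResultIndexingHandler<AnomalyResult, ADIndex, ADIndexManagement> handler = new ResultIndexingHandler<>(
        client,
        settings,
        threadPool,
        ADCommonName.ANOMALY_RESULT_INDEX_ALIAS,    // assumption: the elided index-name argument
        anomalyDetectionIndices,
        clientUtil,
        indexUtil,
        clusterService,
        AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,   // backoff knobs now passed in,
        AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF    // so forecasting can supply its own
    );
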
diff --git a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java
index f6483c8b7..fa7e3acdb 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java
@@ -23,36 +23,29 @@
 import org.junit.Test;
 import org.mockito.ArgumentMatchers;
-import org.opensearch.ad.ratelimit.RequestPriority;
-import org.opensearch.ad.ratelimit.ResultWriteRequest;
+import org.opensearch.ad.ratelimit.ADResultWriteRequest;
 import org.opensearch.ad.transport.ADResultBulkAction;
 import org.opensearch.ad.transport.ADResultBulkRequest;
-import org.opensearch.ad.transport.ADResultBulkResponse;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
+import org.opensearch.timeseries.transport.ResultBulkResponse;
 
 public class MultiEntityResultHandlerTests extends AbstractIndexHandlerTest {
-    private MultiEntityResultHandler handler;
+    private ADIndexMemoryPressureAwareResultHandler handler;
     private ADResultBulkRequest request;
-    private ADResultBulkResponse response;
+    private ResultBulkResponse response;
 
     @Override
     public void setUp() throws Exception {
         super.setUp();
-        handler = new MultiEntityResultHandler(
-            client,
-            settings,
-            threadPool,
-            anomalyDetectionIndices,
-            clientUtil,
-            indexUtil,
-            clusterService
-        );
+        handler = new ADIndexMemoryPressureAwareResultHandler(client, anomalyDetectionIndices);
 
         request = new ADResultBulkRequest();
-        ResultWriteRequest resultWriteRequest = new ResultWriteRequest(
+        ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest(
             Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(),
             detectorId,
             RequestPriority.MEDIUM,
@@ -61,15 +54,15 @@ public void setUp() throws Exception {
         );
         request.add(resultWriteRequest);
 
-        response = new ADResultBulkResponse();
+        response = new ResultBulkResponse();
 
-        super.setUpLog4jForJUnit(MultiEntityResultHandler.class);
+        super.setUpLog4jForJUnit(ADIndexMemoryPressureAwareResultHandler.class);
 
         doAnswer(invocation -> {
-            ActionListener<ADResultBulkResponse> listener = invocation.getArgument(2);
+            ActionListener<ResultBulkResponse> listener = invocation.getArgument(2);
             listener.onResponse(response);
             return null;
-        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ADResultBulkResponse>>any());
+        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ResultBulkResponse>>any());
     }
 
     @Override
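The setUp stub above is the standard pattern for faking a transport call: Mockito hands the answer the real arguments of client.execute(action, request, listener) and completes the listener inline, so no cluster is involved. Spelled out with the generic type arguments that listings like this tend to swallow:

    doAnswer(invocation -> {
        // client.execute(action, request, listener): the listener is argument index 2 (0-based)
        ActionListener<ResultBulkResponse> listener = invocation.getArgument(2);
        listener.onResponse(new ResultBulkResponse());   // or listener.onFailure(...) to drive error paths
        return null;
    }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ResultBulkResponse>>any());
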
@@ -89,10 +82,7 @@ public void testIndexWriteBlock() throws InterruptedException {
             verified.countDown();
         }, exception -> {
             assertTrue(exception instanceof TimeSeriesException);
-            assertTrue(
-                "actual: " + exception.getMessage(),
-                exception.getMessage().contains(MultiEntityResultHandler.CANNOT_SAVE_RESULT_ERR_MSG)
-            );
+            assertTrue("actual: " + exception.getMessage(), exception.getMessage().contains(CommonMessages.CANNOT_SAVE_RESULT_ERR_MSG));
             verified.countDown();
         }));
 
@@ -109,17 +99,17 @@ public void testSavingAdResult() throws IOException, InterruptedException {
             verified.countDown();
         }));
         assertTrue(verified.await(100, TimeUnit.SECONDS));
-        assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false));
+        assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false));
     }
 
     @Test
     public void testSavingFailure() throws IOException, InterruptedException {
         setUpSavingAnomalyResultIndex(false);
         doAnswer(invocation -> {
-            ActionListener<ADResultBulkResponse> listener = invocation.getArgument(2);
+            ActionListener<ResultBulkResponse> listener = invocation.getArgument(2);
             listener.onFailure(new RuntimeException());
             return null;
-        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ADResultBulkResponse>>any());
+        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ResultBulkResponse>>any());
 
         CountDownLatch verified = new CountDownLatch(1);
         handler.flush(request, ActionListener.wrap(response -> {
@@ -142,7 +132,7 @@ public void testAdResultIndexExists() throws IOException, InterruptedException {
             verified.countDown();
         }));
         assertTrue(verified.await(100, TimeUnit.SECONDS));
-        assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false));
+        assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false));
     }
 
     @Test
@@ -200,6 +190,6 @@ public void testCreateResourcExistsException() throws IOException, InterruptedEx
             verified.countDown();
         }));
         assertTrue(verified.await(100, TimeUnit.SECONDS));
-        assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false));
+        assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false));
     }
 }
diff --git a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java
index aadc2d999..5a5e35e81 100644
--- a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java
+++ b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java
@@ -25,6 +25,7 @@
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.index.engine.VersionConflictEngineException;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.util.BulkUtil;
 
 public class BulkUtilTests extends OpenSearchTestCase {
     public void testGetFailedIndexRequest() {
diff --git a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java
index 593445b01..0a5a1fb40 100644
--- a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java
@@ -15,6 +15,7 @@
 
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.util.DateUtils;
 
 public class DateUtilsTests extends OpenSearchTestCase {
     public void testDuration() {
diff --git a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java
index 7234f6feb..cbbffb869 100644
--- a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java
@@ -20,6 +20,7 @@
 import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.util.IndexUtils;
 
 public class IndexUtilsTests extends OpenSearchIntegTestCase {
 
@@ -36,7 +37,7 @@ public void setup() {
 
     @Test
     public void testGetIndexHealth_NoIndex() {
-        IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver);
+        IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver);
         String output = indexUtils.getIndexHealthStatus("test");
         assertEquals(IndexUtils.NONEXISTENT_INDEX_STATUS, output);
     }
 
@@ -46,7 +47,7 @@ public void testGetIndexHealth_Index() {
         String indexName = "test-2";
         createIndex(indexName);
         flush();
-        IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver);
+        IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver);
         String status = indexUtils.getIndexHealthStatus(indexName);
         assertTrue(status.equals("green") || status.equals("yellow"));
     }
 
@@ -59,7 +60,7 @@ public void testGetIndexHealth_Alias() {
         flush();
         AcknowledgedResponse response = client().admin().indices().prepareAliases().addAlias(indexName, aliasName).execute().actionGet();
         assertTrue(response.isAcknowledged());
-        IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver);
+        IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver);
         String status = indexUtils.getIndexHealthStatus(aliasName);
         assertTrue(status.equals("green") || status.equals("yellow"));
     }
diff --git a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java
index c2dd673b4..af919c1cd 100644
--- a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java
@@ -11,7 +11,6 @@
 
 package org.opensearch.ad.util;
 
-import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter;
 import static org.opensearch.timeseries.util.ParseUtils.isAdmin;
 
 import java.io.IOException;
@@ -127,16 +126,17 @@ public void testGenerateInternalFeatureQuery() throws IOException {
 
     public void testAddUserRoleFilterWithNullUser() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        addUserBackendRolesFilter(null, searchSourceBuilder);
+        ParseUtils.addUserBackendRolesFilter(null, searchSourceBuilder);
         assertEquals("{}", searchSourceBuilder.toString());
     }
 
     public void testAddUserRoleFilterWithNullUserBackendRole() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        addUserBackendRolesFilter(
-            new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))),
-            searchSourceBuilder
-        );
+        ParseUtils
+            .addUserBackendRolesFilter(
+                new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))),
+                searchSourceBuilder
+            );
         assertEquals(
             "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[],"
                 + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}],"
@@ -147,15 +147,16 @@ public void testAddUserRoleFilterWithNullUserBackendRole() {
 
     public void testAddUserRoleFilterWithEmptyUserBackendRole() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        addUserBackendRolesFilter(
-            new User(
-                randomAlphaOfLength(5),
-                ImmutableList.of(),
-                ImmutableList.of(randomAlphaOfLength(5)),
-                ImmutableList.of(randomAlphaOfLength(5))
-            ),
-            searchSourceBuilder
-        );
+        ParseUtils
+            .addUserBackendRolesFilter(
+                new User(
+                    randomAlphaOfLength(5),
+                    ImmutableList.of(),
+                    ImmutableList.of(randomAlphaOfLength(5)),
+                    ImmutableList.of(randomAlphaOfLength(5))
+                ),
+                searchSourceBuilder
+            );
         assertEquals(
             "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[],"
                 + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}],"
@@ -168,15 +169,16 @@ public void testAddUserRoleFilterWithNormalUserBackendRole() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
         String backendRole1 = randomAlphaOfLength(5);
         String backendRole2 = randomAlphaOfLength(5);
-        addUserBackendRolesFilter(
-            new User(
-                randomAlphaOfLength(5),
-                ImmutableList.of(backendRole1, backendRole2),
-                ImmutableList.of(randomAlphaOfLength(5)),
-                ImmutableList.of(randomAlphaOfLength(5))
-            ),
-            searchSourceBuilder
-        );
+        ParseUtils
+            .addUserBackendRolesFilter(
+                new User(
+                    randomAlphaOfLength(5),
+                    ImmutableList.of(backendRole1, backendRole2),
+                    ImmutableList.of(randomAlphaOfLength(5)),
+                    ImmutableList.of(randomAlphaOfLength(5))
+                ),
+                searchSourceBuilder
+            );
         assertEquals(
             "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":"
                 + "[\""
diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java
index 28ec18bab..b097d67d8 100644
--- a/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java
+++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskSerializationTests.java
@@ -49,7 +49,7 @@ public void testConstructor_allFieldsPresent() throws IOException {
         assertEquals("task123", readTask.getTaskId());
         assertEquals("FORECAST_HISTORICAL_HC_ENTITY", readTask.getTaskType());
-        assertTrue(readTask.isEntityTask());
+        assertTrue(readTask.isHistoricalEntityTask());
         assertEquals("config123", readTask.getConfigId());
         assertEquals(originalTask.getForecaster(), readTask.getForecaster());
         assertEquals("Running", readTask.getState());
@@ -93,7 +93,7 @@ public void testConstructor_missingOptionalFields() throws IOException {
         assertEquals("task123", readTask.getTaskId());
         assertEquals("FORECAST_HISTORICAL_HC_ENTITY", readTask.getTaskType());
-        assertTrue(readTask.isEntityTask());
+        assertTrue(readTask.isHistoricalEntityTask());
         assertEquals("config123", readTask.getConfigId());
         assertEquals(null, readTask.getForecaster());
         assertEquals("Running", readTask.getState());
diff --git a/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java b/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java
index 4ee403a0e..db309f886 100644
--- a/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java
+++ b/src/test/java/org/opensearch/forecast/model/ForecastTaskTypeTests.java
@@ -11,28 +11,16 @@
 
 public class ForecastTaskTypeTests extends OpenSearchTestCase {
 
-    public void testHistoricalForecasterTaskTypes() {
+    public void testRunOnceForecasterTaskTypes() {
         assertEquals(
-            Arrays.asList(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM),
-            ForecastTaskType.HISTORICAL_FORECASTER_TASK_TYPES
-        );
-    }
-
-    public void testAllHistoricalTaskTypes() {
-        assertEquals(
-            Arrays
-                .asList(
-                    ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER,
-                    ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM,
-                    ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY
-                ),
-            ForecastTaskType.ALL_HISTORICAL_TASK_TYPES
+            Arrays.asList(ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER, ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM),
+            ForecastTaskType.RUN_ONCE_TASK_TYPES
         );
     }
 
     public void testRealtimeTaskTypes() {
         assertEquals(
-            Arrays.asList(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER),
+            Arrays.asList(ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM, ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER),
             ForecastTaskType.REALTIME_TASK_TYPES
         );
     }
 
@@ -41,11 +29,10 @@ public void testAllForecastTaskTypes() {
         assertEquals(
             Arrays
                 .asList(
-                    ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM,
-                    ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER,
-                    ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM,
-                    ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER,
-                    ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY
+                    ForecastTaskType.REALTIME_FORECAST_SINGLE_STREAM,
+                    ForecastTaskType.REALTIME_FORECAST_HC_FORECASTER,
+                    ForecastTaskType.RUN_ONCE_FORECAST_HC_FORECASTER,
+                    ForecastTaskType.RUN_ONCE_FORECAST_SINGLE_STREAM
                 ),
             ForecastTaskType.ALL_FORECAST_TASK_TYPES
         );
diff --git a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java
index 0b64912bf..a5769226e 100644
--- a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java
+++ b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java
@@ -54,6 +54,7 @@ public class ForecasterTests extends AbstractTimeSeriesTest {
     User user = new User("testUser", Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
     String resultIndex = null;
     Integer horizon = 1;
+    double transformDecay = 0.05d;
 
     public void testForecasterConstructor() {
         ImputationOption imputationOption = TestHelpers.randomImputationOption();
@@ -77,7 +78,8 @@ public void testForecasterConstructor() {
             user,
             resultIndex,
             horizon,
-            imputationOption
+            imputationOption,
+            transformDecay
         );
 
         assertEquals(forecasterId, forecaster.getId());
@@ -124,7 +126,8 @@ public void testForecasterConstructorWithNullForecastInterval() {
                 user,
                 resultIndex,
                 horizon,
-                TestHelpers.randomImputationOption()
+                TestHelpers.randomImputationOption(),
+                transformDecay
             );
         });
 
@@ -156,7 +159,8 @@ public void testNegativeInterval() {
                 user,
                 resultIndex,
                 horizon,
-                TestHelpers.randomImputationOption()
+                TestHelpers.randomImputationOption(),
+                transformDecay
             );
         });
 
@@ -188,7 +192,8 @@ public void testMaxCategoryFieldsLimits() {
                 user,
                 resultIndex,
                 horizon,
-                TestHelpers.randomImputationOption()
+                TestHelpers.randomImputationOption(),
+                transformDecay
            );
         });
 
@@ -220,7 +225,8 @@ public void testBlankName() {
                 user,
                 resultIndex,
                 horizon,
-                TestHelpers.randomImputationOption()
+                TestHelpers.randomImputationOption(),
+                transformDecay
             );
         });
 
@@ -252,7 +258,8 @@ public void testInvalidCustomResultIndex() {
                 user,
                 resultIndex,
                 horizon,
-                TestHelpers.randomImputationOption()
+                TestHelpers.randomImputationOption(),
+                transformDecay
             );
         });
 
@@ -283,7 +290,8 @@ public void testValidCustomResultIndex() {
             user,
             resultIndex,
             horizon,
-            TestHelpers.randomImputationOption()
+            TestHelpers.randomImputationOption(),
+            transformDecay
         );
 
         assertEquals(resultIndex, forecaster.getCustomResultIndex());
@@ -312,7 +320,8 @@ public void testInvalidHorizon() {
                 user,
                 resultIndex,
                 horizon,
-                TestHelpers.randomImputationOption()
+                TestHelpers.randomImputationOption(),
+                transformDecay
             );
         });
 
diff --git a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java
index dda3a8761..28dd54e51 100644
--- a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java
+++ b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java
@@ -15,16 +15,4 @@ public void testIsForecastEnabled() {
         assertTrue(!ForecastEnabledSetting.isForecastEnabled());
     }
 
-    public void testIsForecastBreakerEnabled() {
-        assertTrue(ForecastEnabledSetting.isForecastBreakerEnabled());
-        ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED, false);
-        assertTrue(!ForecastEnabledSetting.isForecastBreakerEnabled());
-    }
-
-    public void testIsDoorKeeperInCacheEnabled() {
-        assertTrue(!ForecastEnabledSetting.isDoorKeeperInCacheEnabled());
-        ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, true);
-        assertTrue(ForecastEnabledSetting.isDoorKeeperInCacheEnabled());
-    }
-
 }
diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
index 7f024ef6d..b3e3b7150 100644
--- a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
+++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
@@ -28,13 +28,12 @@
 import org.opensearch.action.get.GetResponse;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
+import org.opensearch.ad.ADTaskProfileRunner;
 import org.opensearch.ad.AbstractProfileRunnerTests;
 import org.opensearch.ad.AnomalyDetectorProfileRunner;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.transport.ProfileAction;
-import org.opensearch.ad.transport.ProfileNodeResponse;
-import org.opensearch.ad.transport.ProfileResponse;
+import org.opensearch.ad.transport.ADProfileAction;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.BigArrays;
@@ -48,6 +47,8 @@
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
 import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.transport.ProfileNodeResponse;
+import org.opensearch.timeseries.transport.ProfileResponse;
 import org.opensearch.timeseries.util.SecurityClientUtil;
 
 /**
@@ -85,7 +86,8 @@ private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus
             nodeFilter,
             requiredSamples,
             transportService,
-            adTaskManager
+            adTaskManager,
+            mock(ADTaskProfileRunner.class)
         );
 
         doAnswer(invocation -> {
@@ -169,7 +171,7 @@ private void setUpMultiEntityClientSearch(ADResultStatus resultStatus, Cardinali
                     for (int i = 0; i < 100; i++) {
                         hyperLogLog.collect(0, BitMixer.mix64(randomIntBetween(1, 100)));
                     }
-                    aggs.add(new InternalCardinality(ADCommonName.TOTAL_ENTITIES, hyperLogLog, new HashMap<>()));
+                    aggs.add(new InternalCardinality(CommonName.TOTAL_ENTITIES, hyperLogLog, new HashMap<>()));
                     when(response.getAggregations()).thenReturn(InternalAggregations.from(aggs));
                     listener.onResponse(response);
                     break;
@@ -204,7 +206,7 @@ private void setUpProfileAction() {
             listener.onResponse(new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, Collections.emptyList()));
 
             return null;
-        }).when(client).execute(eq(ProfileAction.INSTANCE), any(), any());
+        }).when(client).execute(eq(ADProfileAction.INSTANCE), any(), any());
     }
 
     public void testFailGetEntityStats() throws IOException, InterruptedException {
Its content is %s", args.length, Arrays.toString(args)), - args.length >= 2 - ); + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); GetRequest request = null; ActionListener listener = null; diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 23a5150cc..bc7e76252 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -58,10 +58,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.constant.CommonValue; -import org.opensearch.ad.feature.Features; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.mock.model.MockSimpleLog; @@ -72,10 +69,8 @@ import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.AnomalyResultBucket; import org.opensearch.ad.model.DetectorInternalState; -import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.ad.model.ExpectedValueList; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.Request; @@ -133,10 +128,14 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; import org.opensearch.timeseries.dataprocessor.ImputationMethod; import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.feature.Features; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ConfigValidationIssue; import org.opensearch.timeseries.model.DataByFeatureId; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; @@ -148,6 +147,7 @@ import org.opensearch.timeseries.model.TimeConfiguration; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.ratelimit.RequestPriority; import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.google.common.collect.ImmutableList; @@ -321,7 +321,8 @@ public static AnomalyDetector randomAnomalyDetector( categoryFields, user, null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -366,7 +367,8 @@ public static AnomalyDetector randomDetector( categoryFields, null, resultIndex, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -421,7 +423,8 @@ public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( categoryFields, randomUser(), resultIndex, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -452,7 +455,8 @@ public static AnomalyDetector randomAnomalyDetector(String timefield, String ind null, 
randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -475,7 +479,8 @@ public static AnomalyDetector randomAnomalyDetectorWithEmptyFeature() throws IOE null, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -503,7 +508,8 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio categoryField, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -650,7 +656,8 @@ public AnomalyDetector build() { categoryFields, user, resultIndex, - imputationOption + imputationOption, + randomDouble() ); } } @@ -676,7 +683,8 @@ public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguratio categoryField, randomUser(), null, - TestHelpers.randomImputationOption() + TestHelpers.randomImputationOption(), + randomDouble() ); } @@ -888,8 +896,8 @@ public static AnomalyResult randomHCADAnomalyDetectResult(double score, double g return randomHCADAnomalyDetectResult(score, grade, null); } - public static ResultWriteRequest randomResultWriteRequest(String detectorId, double score, double grade) { - ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + public static ADResultWriteRequest randomADResultWriteRequest(String detectorId, double score, double grade) { + ADResultWriteRequest resultWriteRequest = new ADResultWriteRequest( Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), detectorId, RequestPriority.MEDIUM, @@ -980,7 +988,8 @@ public static Job randomAnomalyDetectorJob(boolean enabled, Instant enabledTime, Instant.now().truncatedTo(ChronoUnit.SECONDS), 60L, randomUser(), - null + null, + AnalysisType.AD ); } @@ -1187,6 +1196,15 @@ public static GetResponse createBrokenGetResponse(String id, String indexName) t ); } + public static GetResponse createGetResponse(Map source, String id, String indexName) throws IOException { + XContentBuilder xContent = XContentFactory.jsonBuilder(); + xContent.map(source); + BytesReference documentSource = BytesReference.bytes(xContent); + return new GetResponse( + new GetResult(indexName, id, UNASSIGNED_SEQ_NO, 0, -1, true, documentSource, Collections.emptyMap(), Collections.emptyMap()) + ); + } + public static SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); @@ -1517,8 +1535,8 @@ public static Map parseStatsResult(String statsResult) throws IO return adStats; } - public static DetectorValidationIssue randomDetectorValidationIssue() { - DetectorValidationIssue issue = new DetectorValidationIssue( + public static ConfigValidationIssue randomDetectorValidationIssue() { + ConfigValidationIssue issue = new ConfigValidationIssue( ValidationAspect.DETECTOR, ValidationIssueType.NAME, randomAlphaOfLength(5) @@ -1526,8 +1544,8 @@ public static DetectorValidationIssue randomDetectorValidationIssue() { return issue; } - public static DetectorValidationIssue randomDetectorValidationIssueWithSubIssues(Map subIssues) { - DetectorValidationIssue issue = new DetectorValidationIssue( + public static ConfigValidationIssue randomDetectorValidationIssueWithSubIssues(Map subIssues) { + ConfigValidationIssue issue = new ConfigValidationIssue( ValidationAspect.DETECTOR, ValidationIssueType.NAME, randomAlphaOfLength(5), @@ -1537,11 +1555,11 @@ public static DetectorValidationIssue 
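The new createGetResponse(Map, String, String) overload builds a GetResponse whose source is an arbitrary map, which is convenient for faking reads of config or job documents without composing JSON by hand. A usage sketch; the field values and index name are illustrative only:

    Map<String, Object> source = new HashMap<>();
    source.put("name", "test-detector");
    source.put("enabled", true);
    // fake a successful GET of document "detector-1" from the (illustrative) config index
    GetResponse fakeGet = TestHelpers.createGetResponse(source, "detector-1", ".opendistro-anomaly-detectors");
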
@@ -1517,8 +1535,8 @@ public static Map<String, Object> parseStatsResult(String statsResult) throws IO
         return adStats;
     }
 
-    public static DetectorValidationIssue randomDetectorValidationIssue() {
-        DetectorValidationIssue issue = new DetectorValidationIssue(
+    public static ConfigValidationIssue randomDetectorValidationIssue() {
+        ConfigValidationIssue issue = new ConfigValidationIssue(
             ValidationAspect.DETECTOR,
             ValidationIssueType.NAME,
             randomAlphaOfLength(5)
@@ -1526,8 +1544,8 @@ public static DetectorValidationIssue randomDetectorValidationIssue() {
         return issue;
     }
 
-    public static DetectorValidationIssue randomDetectorValidationIssueWithSubIssues(Map<String, String> subIssues) {
-        DetectorValidationIssue issue = new DetectorValidationIssue(
+    public static ConfigValidationIssue randomDetectorValidationIssueWithSubIssues(Map<String, String> subIssues) {
+        ConfigValidationIssue issue = new ConfigValidationIssue(
             ValidationAspect.DETECTOR,
             ValidationIssueType.NAME,
             randomAlphaOfLength(5),
@@ -1537,11 +1555,11 @@ public static DetectorValidationIssue randomDetectorValidationIssueWithSubIssues
         return issue;
     }
 
-    public static DetectorValidationIssue randomDetectorValidationIssueWithDetectorIntervalRec(long intervalRec) {
-        DetectorValidationIssue issue = new DetectorValidationIssue(
+    public static ConfigValidationIssue randomDetectorValidationIssueWithDetectorIntervalRec(long intervalRec) {
+        ConfigValidationIssue issue = new ConfigValidationIssue(
             ValidationAspect.MODEL,
             ValidationIssueType.DETECTION_INTERVAL,
-            ADCommonMessages.DETECTOR_INTERVAL_REC + intervalRec,
+            CommonMessages.INTERVAL_REC + intervalRec,
             null,
             new IntervalTimeConfiguration(intervalRec, ChronoUnit.MINUTES)
         );
@@ -1757,7 +1775,8 @@ public Forecaster build() {
             user,
             resultIndex,
             horizon,
-            imputationOption
+            imputationOption,
+            randomDouble()
         );
     }
 }
@@ -1782,7 +1801,8 @@ public static Forecaster randomForecaster() throws IOException {
             randomUser(),
             null,
             randomIntBetween(1, 20),
-            randomImputationOption()
+            randomImputationOption(),
+            randomDouble()
         );
     }
diff --git a/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java
index 9731d31b5..ad907c448 100644
--- a/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java
+++ b/src/test/java/org/opensearch/timeseries/feature/SearchFeatureDaoTests.java
@@ -263,7 +263,7 @@ public void getLatestDataTime_returnExpectedToListener() {
         when(searchResponse.getAggregations()).thenReturn(internalAggregations);
 
         ActionListener<Optional<Long>> listener = mock(ActionListener.class);
-        searchFeatureDao.getLatestDataTime(detector, listener);
+        searchFeatureDao.getLatestDataTime(detector, Optional.empty(), AnalysisType.AD, listener);
 
         ArgumentCaptor<Optional<Long>> captor = ArgumentCaptor.forClass(Optional.class);
         verify(listener).onResponse(captor.capture());
diff --git a/src/test/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSettingTests.java b/src/test/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSettingTests.java
new file mode 100644
index 000000000..ae107c7e9
--- /dev/null
+++ b/src/test/java/org/opensearch/timeseries/settings/TimeSeriesEnabledSettingTests.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.timeseries.settings;
+
+import org.opensearch.test.OpenSearchTestCase;
+
+public class TimeSeriesEnabledSettingTests extends OpenSearchTestCase {
+    public void testIsForecastBreakerEnabled() {
+        assertTrue(TimeSeriesEnabledSetting.isBreakerEnabled());
+        TimeSeriesEnabledSetting.getInstance().setSettingValue(TimeSeriesEnabledSetting.BREAKER_ENABLED, false);
+        assertTrue(!TimeSeriesEnabledSetting.isBreakerEnabled());
+    }
+}
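Net effect of the two settings-test changes (the removals from ForecastEnabledSettingTests earlier and this new file): the circuit-breaker toggle moves from a forecast-only setting to a shared time-series one, so a single switch now covers both anomaly detection and forecasting. The new test shows the entire API surface; guarding code would presumably read along these lines (a sketch, not quoted from the source):

    if (!TimeSeriesEnabledSetting.isBreakerEnabled()) {
        // circuit breaker disabled cluster-wide: skip memory-pressure checks for AD and forecast alike
    }
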
diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/ADResultBulkTransportActionTests.java
similarity index 85%
rename from src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java
rename to src/test/java/org/opensearch/timeseries/transport/ADResultBulkTransportActionTests.java
index 9887f1aff..e497988fc 100644
--- a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java
+++ b/src/test/java/org/opensearch/timeseries/transport/ADResultBulkTransportActionTests.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
  */
 
-package org.opensearch.ad.transport;
+package org.opensearch.timeseries.transport;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
@@ -32,6 +32,9 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.ad.transport.ADResultBulkRequest;
+import org.opensearch.ad.transport.ADResultBulkTransportAction;
+import org.opensearch.ad.transport.AnomalyResultTests;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -41,6 +44,7 @@
 import org.opensearch.index.IndexingPressure;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.transport.TransportService;
 
 public class ADResultBulkTransportActionTests extends AbstractTimeSeriesTest {
@@ -98,8 +102,8 @@ public void testSendAll() {
         when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L);
 
         ADResultBulkRequest originalRequest = new ADResultBulkRequest();
-        originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d));
-        originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d));
+        originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d));
+        originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d));
 
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
@@ -118,7 +122,7 @@ public void testSendAll() {
             return null;
         }).when(client).execute(any(), any(), any());
 
-        PlainActionFuture<ADResultBulkResponse> future = PlainActionFuture.newFuture();
+        PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
         resultBulk.doExecute(null, originalRequest, future);
 
         future.actionGet();
@@ -131,8 +135,8 @@ public void testSendPartial() {
         when(indexingPressure.getCurrentReplicaBytes()).thenReturn(24L);
 
         ADResultBulkRequest originalRequest = new ADResultBulkRequest();
-        originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d));
-        originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d));
+        originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d));
+        originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d));
 
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
@@ -151,7 +155,7 @@ public void testSendPartial() {
             return null;
         }).when(client).execute(any(), any(), any());
 
-        PlainActionFuture<ADResultBulkResponse> future = PlainActionFuture.newFuture();
+        PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
         resultBulk.doExecute(null, originalRequest, future);
 
         future.actionGet();
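testSendAll and testSendPartial above pin down the transport action's back-pressure behavior: with zero primary and replica indexing-pressure bytes every queued result is forwarded, while under pressure the action is expected to shed the low-value (zero-grade) results and keep the anomalous one rather than fail the whole batch; testSendRandomPartial below stresses the same path with a thousand non-anomalous results. The typed future the tests block on, with the generics restored:

    PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
    resultBulk.doExecute(null, originalRequest, future);   // the task argument is unused here, hence null
    ResultBulkResponse response = future.actionGet();      // throws if the action completed exceptionally
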
@@ -165,10 +169,10 @@ public void testSendRandomPartial() {
 
         ADResultBulkRequest originalRequest = new ADResultBulkRequest();
         for (int i = 0; i < 1000; i++) {
-            originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d));
+            originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d));
         }
 
-        originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d));
+        originalRequest.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d));
 
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
@@ -190,7 +194,7 @@ public void testSendRandomPartial() {
             return null;
         }).when(client).execute(any(), any(), any());
 
-        PlainActionFuture<ADResultBulkResponse> future = PlainActionFuture.newFuture();
+        PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
         resultBulk.doExecute(null, originalRequest, future);
 
         future.actionGet();
@@ -198,8 +202,8 @@ public void testSerialzationRequest() throws IOException {
         ADResultBulkRequest request = new ADResultBulkRequest();
-        request.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d));
-        request.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d));
+        request.add(TestHelpers.randomADResultWriteRequest(detectorId, 0.8d, 0d));
+        request.add(TestHelpers.randomADResultWriteRequest(detectorId, 8d, 0.2d));
 
         BytesStreamOutput output = new BytesStreamOutput();
         request.writeTo(output);
@@ -210,6 +214,6 @@ public void testValidateRequest() {
         ActionRequestValidationException e = new ADResultBulkRequest().validate();
-        assertThat(e.validationErrors(), hasItem(ADResultBulkRequest.NO_REQUESTS_ADDED_ERR));
+        assertThat(e.validationErrors(), hasItem(CommonMessages.NO_REQUESTS_ADDED_ERR));
     }
 }
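testSerialzationRequest above follows the usual Writeable round-trip pattern: serialize the request into a BytesStreamOutput, rehydrate it from the resulting StreamInput, and compare. Reduced to its core (to be run inside a test method that declares throws IOException; the final comparison is illustrative, the real test asserts on the deserialized request's contents):

    ADResultBulkRequest request = new ADResultBulkRequest();
    request.add(TestHelpers.randomADResultWriteRequest("detector-1", 0.8d, 0d));
    BytesStreamOutput output = new BytesStreamOutput();
    request.writeTo(output);                                               // serialize
    ADResultBulkRequest roundTripped = new ADResultBulkRequest(output.bytes().streamInput());   // deserialize
    // roundTripped should now match request field-for-field
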
diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/ADStatsNodesTransportActionTests.java
similarity index 68%
rename from src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java
rename to src/test/java/org/opensearch/timeseries/transport/ADStatsNodesTransportActionTests.java
index 2284c311e..8bcc0163d 100644
--- a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java
+++ b/src/test/java/org/opensearch/timeseries/transport/ADStatsNodesTransportActionTests.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
  */
 
-package org.opensearch.ad.transport;
+package org.opensearch.timeseries.transport;
 
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -26,18 +26,13 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
-import org.opensearch.ad.ml.ModelManager;
-import org.opensearch.ad.stats.ADStat;
+import org.opensearch.ad.caching.ADCacheProvider;
+import org.opensearch.ad.caching.ADPriorityCache;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.InternalStatNames;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
-import org.opensearch.ad.stats.suppliers.IndexStatusSupplier;
-import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier;
-import org.opensearch.ad.stats.suppliers.SettableSupplier;
+import org.opensearch.ad.stats.suppliers.ADModelsOnNodeSupplier;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.util.IndexUtils;
+import org.opensearch.ad.transport.ADStatsNodesTransportAction;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
 import org.opensearch.cluster.service.ClusterService;
@@ -47,14 +42,19 @@
 import org.opensearch.monitor.jvm.JvmStats;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.threadpool.ThreadPool;
-import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.stats.InternalStatNames;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier;
+import org.opensearch.timeseries.stats.suppliers.SettableSupplier;
+import org.opensearch.timeseries.util.IndexUtils;
 import org.opensearch.transport.TransportService;
 
 public class ADStatsNodesTransportActionTests extends OpenSearchIntegTestCase {
 
     private ADStatsNodesTransportAction action;
     private ADStats adStats;
-    private Map<String, ADStat<?>> statsMap;
+    private Map<String, TimeSeriesStat<?>> statsMap;
     private String clusterStatName1, clusterStatName2;
     private String nodeStatName1, nodeStatName2;
     private ADTaskManager adTaskManager;
@@ -68,10 +68,10 @@ public void setUp() throws Exception {
         Clock clock = mock(Clock.class);
         ThreadPool threadPool = mock(ThreadPool.class);
         IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class);
-        IndexUtils indexUtils = new IndexUtils(client, new ClientUtil(client), clusterService(), indexNameResolver);
-        ModelManager modelManager = mock(ModelManager.class);
-        CacheProvider cacheProvider = mock(CacheProvider.class);
-        EntityCache cache = mock(EntityCache.class);
+        IndexUtils indexUtils = new IndexUtils(clusterService(), indexNameResolver);
+        ADModelManager modelManager = mock(ADModelManager.class);
+        ADCacheProvider cacheProvider = mock(ADCacheProvider.class);
+        ADPriorityCache cache = mock(ADPriorityCache.class);
         when(cacheProvider.get()).thenReturn(cache);
 
         clusterStatName1 = "clusterStat1";
@@ -87,13 +87,16 @@ public void setUp() throws Exception {
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
 
-        statsMap = new HashMap<String, ADStat<?>>() {
+        statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(nodeStatName1, new ADStat<>(false, new CounterSupplier()));
-                put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)));
-                put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1")));
-                put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2")));
-                put(InternalStatNames.JVM_HEAP_USAGE.getName(), new ADStat<>(true, new SettableSupplier()));
+                put(nodeStatName1, new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(
+                    nodeStatName2,
+                    new TimeSeriesStat<>(false, new ADModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))
+                );
+                put(clusterStatName1, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1")));
+                put(clusterStatName2, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index2")));
+                put(InternalStatNames.JVM_HEAP_USAGE.getName(), new TimeSeriesStat<>(true, new SettableSupplier()));
             }
         };
 
@@ -121,10 +124,10 @@ public void setUp() throws Exception {
     @Test
     public void testNewNodeRequest() {
         String nodeId = "nodeId1";
-        ADStatsRequest adStatsRequest = new ADStatsRequest(nodeId);
+        StatsRequest adStatsRequest = new StatsRequest(nodeId);
 
-        ADStatsNodeRequest adStatsNodeRequest1 = new ADStatsNodeRequest(adStatsRequest);
-        ADStatsNodeRequest adStatsNodeRequest2 = action.newNodeRequest(adStatsRequest);
+        StatsNodeRequest adStatsNodeRequest1 = new StatsNodeRequest(adStatsRequest);
+        StatsNodeRequest adStatsNodeRequest2 = action.newNodeRequest(adStatsRequest);
 
         assertEquals(adStatsNodeRequest1.getADStatsRequest(), adStatsNodeRequest2.getADStatsRequest());
     }
 
@@ -132,7 +135,7 @@ public void testNewNodeRequest() {
     @Test
     public void testNodeOperation() {
         String nodeId = clusterService().localNode().getId();
-        ADStatsRequest adStatsRequest = new ADStatsRequest((nodeId));
+        StatsRequest adStatsRequest = new StatsRequest((nodeId));
         adStatsRequest.clear();
 
         Set<String> statsToBeRetrieved = new HashSet<>(Arrays.asList(nodeStatName1, nodeStatName2));
@@ -141,7 +144,7 @@
             adStatsRequest.addStat(stat);
         }
 
-        ADStatsNodeResponse response = action.nodeOperation(new ADStatsNodeRequest(adStatsRequest));
+        StatsNodeResponse response = action.nodeOperation(new StatsNodeRequest(adStatsRequest));
 
         Map<String, Object> stats = response.getStatsMap();
 
@@ -154,7 +157,7 @@
     @Test
     public void testNodeOperationWithJvmHeapUsage() {
         String nodeId = clusterService().localNode().getId();
-        ADStatsRequest adStatsRequest = new ADStatsRequest((nodeId));
+        StatsRequest adStatsRequest = new StatsRequest((nodeId));
         adStatsRequest.clear();
 
         Set<String> statsToBeRetrieved = new HashSet<>(Arrays.asList(nodeStatName1, InternalStatNames.JVM_HEAP_USAGE.getName()));
@@ -163,7 +166,7 @@
             adStatsRequest.addStat(stat);
         }
 
-        ADStatsNodeResponse response = action.nodeOperation(new ADStatsNodeRequest(adStatsRequest));
+        StatsNodeResponse response = action.nodeOperation(new StatsNodeRequest(adStatsRequest));
 
         Map<String, Object> stats = response.getStatsMap();
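The stats test above boils down to a small registry pattern that is now shared between AD and forecasting: each named stat is a TimeSeriesStat whose boolean flag marks it as cluster-level (true) or node-level (false), and a supplier computes the value on demand. Reduced to its core, using only names that appear in the diff (stat keys are illustrative):

    Map<String, TimeSeriesStat<?>> statsMap = new HashMap<>();
    statsMap.put("requests_total", new TimeSeriesStat<>(false, new CounterSupplier()));                        // per-node counter
    statsMap.put("result_index_status", new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); // cluster-wide
    statsMap.put(InternalStatNames.JVM_HEAP_USAGE.getName(), new TimeSeriesStat<>(true, new SettableSupplier()));   // externally set
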
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java
similarity index 75%
rename from src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
rename to src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java
index 42c09d44f..67fd5b8cf 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
+++ b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobActionTests.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
  */
 
-package org.opensearch.ad.transport;
+package org.opensearch.timeseries.transport;
 
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -24,10 +24,10 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.ExecuteADResultResponseRecorder;
-import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.rest.handler.ADIndexJobActionHandler;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
-import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.ad.transport.AnomalyDetectorJobAction;
+import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -37,7 +37,6 @@
 import org.opensearch.commons.ConfigConstants;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.io.stream.StreamInput;
-import org.opensearch.core.rest.RestStatus;
 import org.opensearch.tasks.Task;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.threadpool.ThreadPool;
@@ -47,8 +46,8 @@
 public class AnomalyDetectorJobActionTests extends OpenSearchIntegTestCase {
     private AnomalyDetectorJobTransportAction action;
     private Task task;
-    private AnomalyDetectorJobRequest request;
-    private ActionListener<AnomalyDetectorJobResponse> response;
+    private JobRequest request;
+    private ActionListener<JobResponse> response;
 
     @Override
     @Before
@@ -75,16 +74,14 @@ public void setUp() throws Exception {
             client,
             clusterService,
             indexSettings(),
-            mock(ADIndexManagement.class),
             xContentRegistry(),
-            mock(ADTaskManager.class),
-            mock(ExecuteADResultResponseRecorder.class)
+            mock(ADIndexJobActionHandler.class)
         );
         task = mock(Task.class);
-        request = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_start");
-        response = new ActionListener<AnomalyDetectorJobResponse>() {
+        request = new JobRequest("1234", new DateRange(Instant.ofEpochMilli(4567), Instant.ofEpochMilli(7890)), true, "_start");
+        response = new ActionListener<JobResponse>() {
             @Override
-            public void onResponse(AnomalyDetectorJobResponse adResponse) {
+            public void onResponse(JobResponse adResponse) {
                 // Will not be called as there is no detector
                 Assert.assertTrue(false);
             }
@@ -104,7 +101,12 @@ public void testStartAdJobTransportAction() {
 
     @Test
     public void testStopAdJobTransportAction() {
-        AnomalyDetectorJobRequest stopRequest = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_stop");
+        JobRequest stopRequest = new JobRequest(
+            "1234",
+            new DateRange(Instant.ofEpochMilli(4567), Instant.ofEpochMilli(7890)),
+            true,
+            "_stop"
+        );
         action.doExecute(task, stopRequest, response);
     }
 
@@ -117,13 +119,13 @@ public void testAdJobAction() {
 
     @Test
     public void testAdJobRequest() throws IOException {
         DateRange detectionDateRange = new DateRange(Instant.MIN, Instant.now());
-        request = new AnomalyDetectorJobRequest("1234", detectionDateRange, false, 4567, 7890, "_start");
+        request = new JobRequest("1234", detectionDateRange, false, "_start");
 
         BytesStreamOutput out = new BytesStreamOutput();
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input);
-        Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID());
+        JobRequest newRequest = new JobRequest(input);
+        Assert.assertEquals(request.getConfigID(), newRequest.getConfigID());
     }
 
     @Test
@@ -131,17 +133,17 @@ public void testAdJobRequest_NullDetectionDateRange() throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input);
-        Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID());
+        JobRequest newRequest = new JobRequest(input);
+        Assert.assertEquals(request.getConfigID(), newRequest.getConfigID());
     }
 
     @Test
     public void testAdJobResponse() throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
-        AnomalyDetectorJobResponse response = new AnomalyDetectorJobResponse("1234", 45, 67, 890, RestStatus.OK);
+        JobResponse response = new JobResponse("1234");
         response.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        AnomalyDetectorJobResponse newResponse = new AnomalyDetectorJobResponse(input);
+        JobResponse newResponse = new JobResponse(input);
         Assert.assertEquals(response.getId(), newResponse.getId());
     }
 }
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java
similarity index 81%
rename from src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java
rename to src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java
index 50765deb8..3957aa3bc 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java
+++ b/src/test/java/org/opensearch/timeseries/transport/AnomalyDetectorJobTransportActionTests.java
@@ -9,16 +9,13 @@
  * GitHub history for details.
  */
 
-package org.opensearch.ad.transport;
+package org.opensearch.timeseries.transport;
 
 import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR;
-import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
-import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
 import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;
 import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE;
 import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB;
 import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB;
@@ -44,16 +41,21 @@
 import org.opensearch.ad.mock.model.MockSimpleLog;
 import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobAction;
 import org.opensearch.ad.model.ADTask;
-import org.opensearch.ad.model.ADTaskProfile;
 import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.transport.AnomalyDetectorJobAction;
+import org.opensearch.ad.transport.GetAnomalyDetectorAction;
+import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
+import org.opensearch.ad.transport.StatsAnomalyDetectorAction;
 import org.opensearch.client.Client;
 import org.opensearch.common.lucene.uid.Versions;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.test.OpenSearchIntegTestCase;
+import org.opensearch.timeseries.TaskProfile;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.DateRange;
 import org.opensearch.timeseries.model.Job;
@@ -97,7 +99,7 @@ protected Settings nodeSettings(int nodeOrdinal) {
     public void testDetectorIndexNotFound() {
         deleteDetectorIndex();
         String detectorId = randomAlphaOfLength(5);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
         IndexNotFoundException exception = expectThrows(
             IndexNotFoundException.class,
             () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(3000)
         );
@@ -107,12 +109,12 @@
     public void testDetectorNotFound() {
         String detectorId = randomAlphaOfLength(5);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
         OpenSearchStatusException exception = expectThrows(
             OpenSearchStatusException.class,
             () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000)
         );
-        assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG));
+        assertTrue(exception.getMessage().contains(CommonMessages.FAIL_TO_FIND_CONFIG_MSG));
     }
 
     public void testValidHistoricalAnalysis() throws IOException, InterruptedException {
@@ -126,17 +128,10 @@ public void testStartHistoricalAnalysisWithUser() throws IOException {
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB);
         Client nodeClient = getDataNodeClient();
         if (nodeClient != null) {
-            AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
+            JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
             ADTask adTask = getADTask(response.getId());
             assertNotNull(adTask.getStartedBy());
             assertNotNull(adTask.getUser());
@@ -155,18 +150,11 @@ public void testStartHistoricalAnalysisForSingleCategoryHCWithUser() throws IOEx
             ImmutableList.of(categoryField)
         );
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB);
 
         Client nodeClient = getDataNodeClient();
         if (nodeClient != null) {
-            AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
+            JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
             waitUntil(() -> {
                 try {
                     ADTask task = getADTask(response.getId());
@@ -207,18 +195,11 @@ public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOExc
             ImmutableList.of(categoryField, ipField)
         );
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB);
 
         Client nodeClient = getDataNodeClient();
         if (nodeClient != null) {
-            AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100_000);
+            JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100_000);
             String taskId = response.getId();
 
             waitUntil(() -> {
@@ -252,8 +233,8 @@ public void testRunMultipleTasksForHistoricalAnalysis() throws IOException, Inte
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
-        AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
         assertNotNull(response.getId());
         OpenSearchStatusException exception = null;
         // Add retry to solve the flaky test
@@ -282,14 +263,7 @@ public void testRaceConditionByStartingMultipleTasks() throws IOException, Inter
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); client().execute(AnomalyDetectorJobAction.INSTANCE, request); client().execute(AnomalyDetectorJobAction.INSTANCE, request); @@ -317,16 +291,9 @@ public void testCleanOldTaskDocs() throws InterruptedException, IOException { long count = countDocs(ADCommonName.DETECTION_STATE_INDEX); assertEquals(states.size(), count); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - randomLong(), - randomLong(), - START_JOB - ); + JobRequest request = new JobRequest(detectorId, dateRange, true, START_JOB); - AtomicReference response = new AtomicReference<>(); + AtomicReference response = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); Thread.sleep(2000); client().execute(AnomalyDetectorJobAction.INSTANCE, request, ActionListener.wrap(r -> { @@ -368,8 +335,8 @@ private List startRealtimeDetector() throws IOException { AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, null); - AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + JobRequest request = startDetectorJobRequest(detectorId, null); + JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); String jobId = response.getId(); assertEquals(detectorId, jobId); return ImmutableList.of(detectorId, jobId); @@ -399,7 +366,7 @@ public void testHistoricalDetectorWithoutEnabledFeature() throws IOException { private void testInvalidDetector(AnomalyDetector detector, String error) throws IOException { String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + JobRequest request = startDetectorJobRequest(detectorId, dateRange); OpenSearchStatusException exception = expectThrows( OpenSearchStatusException.class, () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) @@ -407,12 +374,12 @@ private void testInvalidDetector(AnomalyDetector detector, String error) throws assertEquals(error, exception.getMessage()); } - private AnomalyDetectorJobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) { - return new AnomalyDetectorJobRequest(detectorId, dateRange, false, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB); + private JobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) { + return new JobRequest(detectorId, dateRange, false, START_JOB); } - private AnomalyDetectorJobRequest stopDetectorJobRequest(String detectorId, boolean historical) { - return new AnomalyDetectorJobRequest(detectorId, null, historical, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, STOP_JOB); + private JobRequest stopDetectorJobRequest(String detectorId, boolean historical) { + return new JobRequest(detectorId, null, historical, STOP_JOB); } public void testStopRealtimeDetector() throws IOException { @@ -420,7 +387,7 @@ public void testStopRealtimeDetector() throws IOException { String detectorId = realtimeResult.get(0); String jobId = realtimeResult.get(1); - AnomalyDetectorJobRequest request = stopDetectorJobRequest(detectorId, false); + JobRequest 
request = stopDetectorJobRequest(detectorId, false); client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); GetResponse doc = getDoc(CommonName.JOB_INDEX, detectorId); Job job = toADJob(doc); @@ -447,7 +414,7 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio if (taskRunning) { // It's possible that the task not started on worker node yet. Recancel it to make sure // task cancelled. - AnomalyDetectorJobRequest request = stopDetectorJobRequest(adTask.getConfigId(), true); + JobRequest request = stopDetectorJobRequest(adTask.getConfigId(), true); client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); } return !taskRunning; @@ -462,9 +429,9 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio public void testProfileHistoricalDetector() throws IOException, InterruptedException { ADTask adTask = startHistoricalAnalysis(startTime, endTime); - GetAnomalyDetectorRequest request = taskProfileRequest(adTask.getConfigId()); + GetConfigRequest request = taskProfileRequest(adTask.getConfigId()); GetAnomalyDetectorResponse response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); - assertTrue(response.getDetectorProfile().getAdTaskProfile() != null); + assertTrue(response.getDetectorProfile().getTaskProfile() != null); ADTask finishedTask = getADTask(adTask.getTaskId()); int i = 0; @@ -476,8 +443,8 @@ public void testProfileHistoricalDetector() throws IOException, InterruptedExcep assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(finishedTask.getState())); response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); - assertNull(response.getDetectorProfile().getAdTaskProfile().getNodeId()); - ADTask profileAdTask = response.getDetectorProfile().getAdTaskProfile().getAdTask(); + assertNull(response.getDetectorProfile().getTaskProfile().getNodeId()); + ADTask profileAdTask = response.getDetectorProfile().getTaskProfile().getTask(); assertEquals(finishedTask.getTaskId(), profileAdTask.getTaskId()); assertEquals(finishedTask.getConfigId(), profileAdTask.getConfigId()); assertEquals(finishedTask.getDetector(), profileAdTask.getDetector()); @@ -488,28 +455,28 @@ public void testProfileWithMultipleRunningTask() throws IOException { ADTask adTask1 = startHistoricalAnalysis(startTime, endTime); ADTask adTask2 = startHistoricalAnalysis(startTime, endTime); - GetAnomalyDetectorRequest request1 = taskProfileRequest(adTask1.getConfigId()); - GetAnomalyDetectorRequest request2 = taskProfileRequest(adTask2.getConfigId()); + GetConfigRequest request1 = taskProfileRequest(adTask1.getConfigId()); + GetConfigRequest request2 = taskProfileRequest(adTask2.getConfigId()); GetAnomalyDetectorResponse response1 = client().execute(GetAnomalyDetectorAction.INSTANCE, request1).actionGet(10000); GetAnomalyDetectorResponse response2 = client().execute(GetAnomalyDetectorAction.INSTANCE, request2).actionGet(10000); - ADTaskProfile taskProfile1 = response1.getDetectorProfile().getAdTaskProfile(); - ADTaskProfile taskProfile2 = response2.getDetectorProfile().getAdTaskProfile(); + TaskProfile taskProfile1 = response1.getDetectorProfile().getTaskProfile(); + TaskProfile taskProfile2 = response2.getDetectorProfile().getTaskProfile(); assertNotNull(taskProfile1.getNodeId()); assertNotNull(taskProfile2.getNodeId()); assertNotEquals(taskProfile1.getNodeId(), taskProfile2.getNodeId()); } - private GetAnomalyDetectorRequest taskProfileRequest(String detectorId) 
throws IOException { - return new GetAnomalyDetectorRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null); + private GetConfigRequest taskProfileRequest(String detectorId) throws IOException { + return new GetConfigRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null); } private long getExecutingADTask() { - ADStatsRequest adStatsRequest = new ADStatsRequest(getDataNodesArray()); + StatsRequest adStatsRequest = new StatsRequest(getDataNodesArray()); Set validStats = ImmutableSet.of(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName()); adStatsRequest.addAll(validStats); - StatsAnomalyDetectorResponse statsResponse = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5000); + StatsTimeSeriesResponse statsResponse = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5000); AtomicLong totalExecutingTask = new AtomicLong(0); - statsResponse.getAdStatsResponse().getADStatsNodesResponse().getNodes().forEach(node -> { + statsResponse.getAdStatsResponse().getStatsNodesResponse().getNodes().forEach(node -> { totalExecutingTask.getAndAdd((Long) node.getStatsMap().get(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName())); }); return totalExecutingTask.get(); diff --git a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java similarity index 76% rename from src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java index 7c3de7ed2..487ffb185 100644 --- a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/CronTransportActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -23,12 +23,11 @@ import org.junit.Before; import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -38,9 +37,14 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.transport.TransportService; import com.google.gson.JsonElement; @@ -65,14 +69,20 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); NodeStateManager tarnsportStatemanager = mock(NodeStateManager.class); - ModelManager modelManager = mock(ModelManager.class); + ADModelManager modelManager = mock(ADModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); - CacheProvider cacheProvider = mock(CacheProvider.class); - EntityCache entityCache = mock(EntityCache.class); - EntityColdStarter entityColdStarter = mock(EntityColdStarter.class); + ADCacheProvider cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache entityCache = mock(ADPriorityCache.class); + ADEntityColdStart entityColdStarter = mock(ADEntityColdStart.class); when(cacheProvider.get()).thenReturn(entityCache); ADTaskManager adTaskManager = mock(ADTaskManager.class); + ForecastCacheProvider forecastCacheProvider = mock(ForecastCacheProvider.class); + ForecastPriorityCache forecastCache = mock(ForecastPriorityCache.class); + ForecastColdStart forecastColdStarter = mock(ForecastColdStart.class); + when(forecastCacheProvider.get()).thenReturn(forecastCache); + ForecastTaskManager forecastTaskManager = mock(ForecastTaskManager.class); + action = new CronTransportAction( threadPool, clusterService, @@ -82,8 +92,11 @@ public void setUp() throws Exception { modelManager, featureManager, cacheProvider, + forecastCacheProvider, entityColdStarter, - adTaskManager + forecastColdStarter, + adTaskManager, + forecastTaskManager ); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorActionTests.java similarity index 84% rename from src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java rename to 
src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorActionTests.java index 93e291325..00c667b86 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -27,6 +27,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; +import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; @@ -35,6 +37,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.transport.TransportService; public class DeleteAnomalyDetectorActionTests extends OpenSearchIntegTestCase { @@ -60,6 +63,7 @@ public void setUp() throws Exception { clusterService, Settings.EMPTY, xContentRegistry(), + mock(NodeStateManager.class), adTaskManager ); response = new ActionListener() { @@ -83,18 +87,18 @@ public void testStatsAction() { @Test public void testDeleteRequest() throws IOException { - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - DeleteAnomalyDetectorRequest newRequest = new DeleteAnomalyDetectorRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + DeleteConfigRequest newRequest = new DeleteConfigRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); Assert.assertNull(newRequest.validate()); } @Test public void testEmptyDeleteRequest() { - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest(""); + DeleteConfigRequest request = new DeleteConfigRequest(""); ActionRequestValidationException exception = request.validate(); Assert.assertNotNull(exception); } @@ -103,14 +107,14 @@ public void testEmptyDeleteRequest() { public void testTransportActionWithAdIndex() { // DeleteResponse is not called because detector ID will not exist createIndex(".opendistro-anomaly-detector-jobs"); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); action.doExecute(mock(Task.class), request, response); } @Test public void testTransportActionWithoutAdIndex() throws IOException { // DeleteResponse is not called because detector ID will not exist - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); action.doExecute(mock(Task.class), request, response); } } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorTests.java similarity index 88% rename from 
src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java rename to src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorTests.java index 9d369b121..8092368e3 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/DeleteAnomalyDetectorTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -38,6 +38,7 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -52,7 +53,10 @@ import org.opensearch.index.get.GetResult; import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.tasks.Task; +import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; @@ -72,6 +76,7 @@ public class DeleteAnomalyDetectorTests extends AbstractTimeSeriesTest { private GetResponse getResponse; ClusterService clusterService; private Job jobParameter; + private NodeStateManager nodeStatemanager; @BeforeClass public static void setUpBeforeClass() { @@ -99,7 +104,8 @@ public void setUp() throws Exception { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); client = mock(Client.class); @@ -107,6 +113,7 @@ public void setUp() throws Exception { actionFilters = mock(ActionFilters.class); adTaskManager = mock(ADTaskManager.class); + nodeStatemanager = mock(NodeStateManager.class); action = new DeleteAnomalyDetectorTransportAction( transportService, actionFilters, @@ -114,6 +121,7 @@ public void setUp() throws Exception { clusterService, Settings.EMPTY, xContentRegistry(), + nodeStatemanager, adTaskManager ); @@ -126,32 +134,32 @@ public void setUp() throws Exception { public void testDeleteADTransportAction_FailDeleteResponse() { future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(true, true, false, false); action.doExecute(mock(Task.class), request, future); - verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(adTaskManager).deleteTasks(eq("1234"), any(), any()); verify(client, times(1)).delete(any(), any()); verify(future).onFailure(any(OpenSearchStatusException.class)); } public void testDeleteADTransportAction_NullAnomalyDetector() { future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(true, false, false, false); action.doExecute(mock(Task.class), request, future); - verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(adTaskManager).deleteTasks(eq("1234"), any(), 
any()); verify(client, times(3)).delete(any(), any()); } public void testDeleteADTransportAction_DeleteResponseException() { future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(true, false, true, false); action.doExecute(mock(Task.class), request, future); - verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(adTaskManager).deleteTasks(eq("1234"), any(), any()); verify(client, times(1)).delete(any(), any()); verify(future).onFailure(any(RuntimeException.class)); } @@ -165,10 +173,10 @@ public void testDeleteADTransportAction_LatestDetectorLevelTask() { ADTask adTask = ADTask.builder().state("RUNNING").build(); consumer.accept(Optional.of(adTask)); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any()); future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(false, false, false, false); action.doExecute(mock(Task.class), request, future); @@ -178,7 +186,7 @@ public void testDeleteADTransportAction_LatestDetectorLevelTask() { public void testDeleteADTransportAction_JobRunning() { when(clusterService.state()).thenReturn(createClusterState()); future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(false, false, false, false); action.doExecute(mock(Task.class), request, future); @@ -188,7 +196,7 @@ public void testDeleteADTransportAction_JobRunning() { public void testDeleteADTransportAction_GetResponseException() { when(clusterService.state()).thenReturn(createClusterState()); future = mock(PlainActionFuture.class); - DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + DeleteConfigRequest request = new DeleteConfigRequest("1234"); setupMocks(false, false, false, true); action.doExecute(mock(Task.class), request, future); @@ -244,7 +252,7 @@ private void setupMocks( consumer.accept(Optional.of(ad)); } return null; - }).when(adTaskManager).getDetector(any(), any(), any()); + }).when(nodeStatemanager).getConfig(any(), any(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -252,7 +260,7 @@ private void setupMocks( function.execute(); return null; - }).when(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + }).when(adTaskManager).deleteTasks(eq("1234"), any(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -298,7 +306,8 @@ private void setupMocks( Instant.now(), 60L, TestHelpers.randomUser(), - jobParameter.getCustomResultIndex() + jobParameter.getCustomResultIndex(), + AnalysisType.AD ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) ), Collections.emptyMap(), diff --git a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java b/src/test/java/org/opensearch/timeseries/transport/EntityProfileTests.java similarity index 85% rename from src/test/java/org/opensearch/ad/transport/EntityProfileTests.java rename to src/test/java/org/opensearch/timeseries/transport/EntityProfileTests.java index 32b3226b4..d4e6cf8bf 100644 --- 
a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/EntityProfileTests.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyString; @@ -31,15 +31,12 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.ad.transport.ADEntityProfileAction; +import org.opensearch.ad.transport.ADEntityProfileTransportAction; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -48,11 +45,16 @@ import org.opensearch.core.transport.TransportResponse; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.tasks.Task; +import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.EntityProfileName; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.model.ModelProfileOnNode; import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; @@ -77,8 +79,8 @@ public class EntityProfileTests extends AbstractTimeSeriesTest { private TransportService transportService; private Settings settings; private ClusterService clusterService; - private CacheProvider cacheProvider; - private EntityProfileTransportAction action; + private ADCacheProvider cacheProvider; + private ADEntityProfileTransportAction action; private Task task; private PlainActionFuture future; private TransportAddress transportAddress1; @@ -125,7 +127,8 @@ public void setUp() throws Exception { TransportService.NOOP_TRANSPORT_INTERCEPTOR, x -> null, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ); settings = Settings.EMPTY; @@ -133,18 +136,18 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); - cacheProvider = mock(CacheProvider.class); - EntityCache cache = mock(EntityCache.class); + cacheProvider = mock(ADCacheProvider.class); + ADPriorityCache cache = mock(ADPriorityCache.class); updates = 1L; when(cache.getTotalUpdates(anyString(), anyString())).thenReturn(updates); when(cache.isActive(anyString(), anyString())).thenReturn(isActive); - when(cache.getLastActiveMs(anyString(), anyString())).thenReturn(lastActiveTimestamp); + when(cache.getLastActiveTime(anyString(), 
anyString())).thenReturn(lastActiveTimestamp); Map modelSizeMap = new HashMap<>(); modelSizeMap.put(modelId, modelSize); when(cache.getModelSize(anyString())).thenReturn(modelSizeMap); when(cacheProvider.get()).thenReturn(cache); - action = new EntityProfileTransportAction(actionFilters, transportService, settings, hashRing, clusterService, cacheProvider); + action = new ADEntityProfileTransportAction(actionFilters, transportService, settings, hashRing, clusterService, cacheProvider); future = new PlainActionFuture<>(); transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); @@ -165,7 +168,7 @@ public void sendRequest( TransportRequestOptions options, TransportResponseHandler handler ) { - if (EntityProfileAction.NAME.equals(action)) { + if (ADEntityProfileAction.NAME.equals(action)) { sender.sendRequest(connection, action, request, options, entityProfileHandler(handler)); } else { sender.sendRequest(connection, action, request, options, handler); @@ -187,7 +190,7 @@ public void sendRequest( TransportRequestOptions options, TransportResponseHandler handler ) { - if (EntityProfileAction.NAME.equals(action)) { + if (ADEntityProfileAction.NAME.equals(action)) { sender.sendRequest(connection, action, request, options, entityFailureProfileandler(handler)); } else { sender.sendRequest(connection, action, request, options, handler); @@ -236,7 +239,7 @@ public void handleResponse(T response) { .handleException( new ConnectTransportException( new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()), - EntityProfileAction.NAME + ADEntityProfileAction.NAME ) ); } @@ -254,7 +257,7 @@ public String executor() { } private void registerHandler(FakeNode node) { - new EntityProfileTransportAction( + new ADEntityProfileTransportAction( new ActionFilters(Collections.emptySet()), node.transportService, Settings.EMPTY, @@ -265,15 +268,15 @@ private void registerHandler(FakeNode node) { } public void testInvalidRequest() { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.empty()); action.doExecute(task, request, future); - assertException(future, TimeSeriesException.class, EntityProfileTransportAction.NO_NODE_FOUND_MSG); + assertException(future, TimeSeriesException.class, ADEntityProfileTransportAction.NO_NODE_FOUND_MSG); } public void testLocalNodeHit() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); action.doExecute(task, request, future); @@ -283,7 +286,7 @@ public void testLocalNodeHit() { public void testAllHit() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); request = new EntityProfileRequest(detectorId, entity, all); @@ -300,7 +303,7 @@ public void 
testGetRemoteUpdateResponse() { TransportService realTransportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; - action = new EntityProfileTransportAction( + action = new ADEntityProfileTransportAction( actionFilters, realTransportService, settings, @@ -309,7 +312,7 @@ public void testGetRemoteUpdateResponse() { cacheProvider ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -330,7 +333,7 @@ public void testGetRemoteFailureResponse() { TransportService realTransportService = testNodes[0].transportService; clusterService = testNodes[0].clusterService; - action = new EntityProfileTransportAction( + action = new ADEntityProfileTransportAction( actionFilters, realTransportService, settings, @@ -339,7 +342,7 @@ public void testGetRemoteFailureResponse() { cacheProvider ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -359,7 +362,7 @@ public void testResponseToXContent() throws IOException, JsonPathNotFoundExcepti EntityProfileResponse response = builder.build(); String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); assertEquals(lastActiveTimestamp, JsonDeserializer.getLongValue(json, EntityProfileResponse.LAST_ACTIVE_TS)); - assertEquals(modelSize, JsonDeserializer.getChildNode(json, ADCommonName.MODEL, CommonName.MODEL_SIZE_IN_BYTES).getAsLong()); + assertEquals(modelSize, JsonDeserializer.getChildNode(json, CommonName.MODEL, CommonName.MODEL_SIZE_IN_BYTES).getAsLong()); } public void testResponseHashCodeEquals() { @@ -376,8 +379,8 @@ public void testResponseHashCodeEquals() { } public void testEntityProfileName() { - assertEquals("state", EntityProfileName.getName(ADCommonName.STATE).getName()); - assertEquals("models", EntityProfileName.getName(ADCommonName.MODELS).getName()); + assertEquals("state", EntityProfileName.getName(CommonName.STATE).getName()); + assertEquals("models", EntityProfileName.getName(CommonName.MODELS).getName()); IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> EntityProfileName.getName("abc")); assertEquals(exception.getMessage(), ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); } diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java b/src/test/java/org/opensearch/timeseries/transport/SearchAnomalyDetectorInfoActionTests.java similarity index 82% rename from src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java rename to src/test/java/org/opensearch/timeseries/transport/SearchAnomalyDetectorInfoActionTests.java index f06761bb6..47a5d0877 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java +++ b/src/test/java/org/opensearch/timeseries/transport/SearchAnomalyDetectorInfoActionTests.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -30,6 +30,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -46,15 +48,15 @@ import org.opensearch.transport.TransportService; public class SearchAnomalyDetectorInfoActionTests extends OpenSearchIntegTestCase { - private SearchAnomalyDetectorInfoRequest request; - private ActionListener response; + private SearchConfigInfoRequest request; + private ActionListener response; private SearchAnomalyDetectorInfoTransportAction action; private Task task; private ClusterService clusterService; private Client client; private ThreadPool threadPool; ThreadContext threadContext; - private PlainActionFuture future; + private PlainActionFuture future; @Override @Before @@ -67,9 +69,9 @@ public void setUp() throws Exception { clusterService() ); task = mock(Task.class); - response = new ActionListener() { + response = new ActionListener() { @Override - public void onResponse(SearchAnomalyDetectorInfoResponse response) { + public void onResponse(SearchConfigInfoResponse response) { Assert.assertEquals(response.getCount(), 0); Assert.assertEquals(response.isNameExists(), false); } @@ -100,14 +102,14 @@ public void onFailure(Exception e) { @Test public void testSearchCount() throws IOException { // Anomaly Detectors index will not exist, onResponse will be called - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest(null, "count"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest(null, "count"); action.doExecute(task, request, response); } @Test public void testSearchMatch() throws IOException { // Anomaly Detectors index will not exist, onResponse will be called - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); action.doExecute(task, request, response); } @@ -119,11 +121,11 @@ public void testSearchInfoAction() { @Test public void testSearchInfoRequest() throws IOException { - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - SearchAnomalyDetectorInfoRequest newRequest = new SearchAnomalyDetectorInfoRequest(input); + SearchConfigInfoRequest newRequest = new SearchConfigInfoRequest(input); Assert.assertEquals(request.getName(), newRequest.getName()); Assert.assertEquals(request.getRawPath(), newRequest.getRawPath()); Assert.assertNull(newRequest.validate()); @@ -131,11 +133,11 @@ public void testSearchInfoRequest() throws IOException { @Test public void testSearchInfoResponse() throws IOException { - SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(1, true); + SearchConfigInfoResponse response = new SearchConfigInfoResponse(1, true); 
BytesStreamOutput out = new BytesStreamOutput(); response.writeTo(out); StreamInput input = out.bytes().streamInput(); - SearchAnomalyDetectorInfoResponse newResponse = new SearchAnomalyDetectorInfoResponse(input); + SearchConfigInfoResponse newResponse = new SearchConfigInfoResponse(input); Assert.assertEquals(response.getCount(), newResponse.getCount()); Assert.assertEquals(response.isNameExists(), newResponse.isNameExists()); Assert.assertNotNull(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); @@ -156,9 +158,9 @@ public void testSearchInfoResponse_CountSuccessWithEmptyResponse() throws IOExce client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "count"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "count"); action.doExecute(task, request, future); - verify(future).onResponse(any(SearchAnomalyDetectorInfoResponse.class)); + verify(future).onResponse(any(SearchConfigInfoResponse.class)); } public void testSearchInfoResponse_MatchSuccessWithEmptyResponse() throws IOException { @@ -176,9 +178,9 @@ public void testSearchInfoResponse_MatchSuccessWithEmptyResponse() throws IOExce client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); action.doExecute(task, request, future); - verify(future).onResponse(any(SearchAnomalyDetectorInfoResponse.class)); + verify(future).onResponse(any(SearchConfigInfoResponse.class)); } public void testSearchInfoResponse_CountRuntimeException() throws IOException { @@ -194,7 +196,7 @@ public void testSearchInfoResponse_CountRuntimeException() throws IOException { client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "count"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "count"); action.doExecute(task, request, future); verify(future).onFailure(any(RuntimeException.class)); } @@ -212,7 +214,7 @@ public void testSearchInfoResponse_MatchRuntimeException() throws IOException { client, clusterService ); - SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + SearchConfigInfoRequest request = new SearchConfigInfoRequest("testDetector", "match"); action.doExecute(task, request, future); verify(future).onFailure(any(RuntimeException.class)); } diff --git a/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java b/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java index d4241fc4f..031c234c8 100644 --- a/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java +++ b/src/test/java/org/opensearch/timeseries/util/ClientUtilTests.java @@ -78,7 +78,8 @@ public void testAsyncRequestOnSuccess() throws InterruptedException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); BiConsumer> consumer = (request, actionListener) -> { // simulate successful operation @@ -122,7 +123,8 @@ public void testExecuteOnSuccess() throws InterruptedException { new double[] { randomDouble(), randomDouble() }, new double[][] { new double[] { randomDouble(), randomDouble() } }, new double[] { randomDouble() }, - randomDoubleBetween(1.1, 
10.0, true) + randomDoubleBetween(1.1, 10.0, true), + null ); doAnswer(invocationOnMock -> { ((ActionListener) invocationOnMock.getArguments()[2]).onResponse(expected); diff --git a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java index 5c66d8f54..a74dbe83c 100644 --- a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java +++ b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java @@ -106,4 +106,36 @@ public static ClusterState state(int numDataNodes) { } return state(new ClusterName("test"), clusterManagerNode, clusterManagerNode, allNodes); } + + public static void main(String[] args) { + long start = System.currentTimeMillis(); + boolean condition = true; + int index = 1; + int getCutValue = 2; + int getCutDimension = 3; + String leftBox = "abcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc"; + String rightBox = "deffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + try { + for (int i = 0; i < 10000; i++) { + if (condition) { + throw new IllegalStateException( + " incorrect bounding state at index " + + index + + " cut value " + + getCutValue + + " cut dimension " + + getCutDimension + + " left Box " + + leftBox + + " right box " + + rightBox + ); + } + } + + } finally { + long finish = System.currentTimeMillis(); + System.out.println(finish - start); + } + } } diff --git a/src/test/java/test/org/opensearch/ad/util/FakeNode.java b/src/test/java/test/org/opensearch/ad/util/FakeNode.java index 58f3f14bb..1fc43e62d 100644 --- a/src/test/java/test/org/opensearch/ad/util/FakeNode.java +++ b/src/test/java/test/org/opensearch/ad/util/FakeNode.java @@ -48,6 +48,7 @@ import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; import org.opensearch.tasks.TaskManager; import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.tasks.MockTaskManager; import org.opensearch.threadpool.ThreadPool; @@ -79,7 +80,8 @@ public FakeNode( new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, new NamedWriteableRegistry(ClusterModule.getNamedWriteables()), - new NoneCircuitBreakerService() + new NoneCircuitBreakerService(), + NoopTracer.INSTANCE ) { @Override public TransportAddress[] addressesFromString(String address) { @@ -90,7 +92,8 @@ public TransportAddress[] addressesFromString(String address) { transportInterceptor, boundTransportAddressDiscoveryNodeFunction, null, - Collections.emptySet() + Collections.emptySet(), + NoopTracer.INSTANCE ) { @Override protected TaskManager createTaskManager( diff --git a/src/test/java/test/org/opensearch/ad/util/MLUtil.java b/src/test/java/test/org/opensearch/ad/util/MLUtil.java index 6b6bb39af..8f9025f46 100644 --- a/src/test/java/test/org/opensearch/ad/util/MLUtil.java +++ b/src/test/java/test/org/opensearch/ad/util/MLUtil.java @@ -14,17 +14,20 @@ import static java.lang.Math.PI; import java.time.Clock; +import java.time.Instant; import java.util.ArrayDeque; +import java.util.Deque; import java.util.HashMap; import java.util.Map; -import java.util.Queue; +import java.util.Optional; import java.util.Random; import java.util.stream.IntStream; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.common.collect.Tuple; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -53,13 +56,13 @@ private static String randomString(int targetStringLength) { .toString(); } - public static Queue<double[]> createQueueSamples(int size) { - Queue<double[]> res = new ArrayDeque<>(); - IntStream.range(0, size).forEach(i -> res.offer(new double[] { random.nextDouble() })); + public static Deque<Sample> createQueueSamples(int size) { + Deque<Sample> res = new ArrayDeque<>(); + IntStream.range(0, size).forEach(i -> res.offer(new Sample(new double[] { random.nextDouble() }, Instant.now(), Instant.now()))); return res; } - public static ModelState<EntityModel> randomModelState(RandomModelStateConfig config) { + public static ModelState<ThresholdedRandomCutForest> randomModelState(RandomModelStateConfig config) { boolean fullModel = config.getFullModel() != null && config.getFullModel().booleanValue() ? true : false; float priority = config.getPriority() != null ? config.getPriority() : random.nextFloat(); String detectorId = config.getId() != null ? config.getId() : randomString(15); @@ -75,27 +78,37 @@ public static ModelState<EntityModel> randomModelState(RandomModelStateConfig co } else { entity = Entity.createSingleAttributeEntity("", ""); } - EntityModel model = null; + Pair<ThresholdedRandomCutForest, Deque<Sample>> model = null; if (fullModel) { model = createNonEmptyModel(detectorId, sampleSize, entity); } else { model = createEmptyModel(entity, sampleSize); } - return new ModelState<>(model, detectorId, detectorId, ModelType.ENTITY.getName(), clock, priority); + return new ModelState<>( + model.getLeft(), + detectorId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + priority, + null, + Optional.of(entity), + model.getRight() + ); } - public static EntityModel createEmptyModel(Entity entity, int sampleSize) { - Queue<double[]> samples = createQueueSamples(sampleSize); - return new EntityModel(entity, samples, null); + public static Pair<ThresholdedRandomCutForest, Deque<Sample>> createEmptyModel(Entity entity, int sampleSize) { + Deque<Sample> samples = createQueueSamples(sampleSize); + return Pair.of(null, samples); } - public static EntityModel createEmptyModel(Entity entity) { + public static Pair<ThresholdedRandomCutForest, Deque<Sample>> createEmptyModel(Entity entity) { return createEmptyModel(entity, random.nextInt(minSampleSize)); } - public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, Entity entity) { - Queue<double[]> samples = createQueueSamples(sampleSize); + public static Pair<ThresholdedRandomCutForest, Deque<Sample>> createNonEmptyModel(String detectorId, int sampleSize, Entity entity) { + Deque<Sample> samples = createQueueSamples(sampleSize); int numDataPoints = random.nextInt(1000) + TimeSeriesSettings.NUM_MIN_SAMPLES; ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest( ThresholdedRandomCutForest .builder() @@ -116,11 +129,10 @@ public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, for (int i = 0; i < numDataPoints; i++) { trcf.process(new double[] { random.nextDouble() }, i); } - EntityModel entityModel = new EntityModel(entity, samples, trcf); - return entityModel; + return Pair.of(trcf, samples); } - public static EntityModel createNonEmptyModel(String detectorId) { + public static Pair<ThresholdedRandomCutForest, Deque<Sample>> createNonEmptyModel(String detectorId) { return createNonEmptyModel(detectorId, random.nextInt(minSampleSize), Entity.createSingleAttributeEntity("", "")); } @@ -179,11 +191,11 @@ static double[] getDataD(int num, double amplitude, double noise, long seed) { * @param rcfConfig RCF config * @return models and return training samples */ - public static Tuple<Queue<double[]>, ThresholdedRandomCutForest> prepareModel( + public static Tuple<Deque<Sample>, ThresholdedRandomCutForest> prepareModel( int inputDimension, ThresholdedRandomCutForest.Builder rcfConfig ) { - Queue<double[]> samples = new ArrayDeque<>(); + Deque<Sample> samples = new ArrayDeque<>(); Random r = new Random(); ThresholdedRandomCutForest rcf = new ThresholdedRandomCutForest(rcfConfig); @@ -192,7 +204,7 @@ public static Tuple<Queue<double[]>, ThresholdedRandomCutForest> prepareModel( for (int i = 0; i < trainDataNum; i++) { double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); - samples.add(point); + samples.add(new Sample(point, Instant.now(), Instant.now())); rcf.process(point, 0); }
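The renamed request and response tests above (JobRequest, JobResponse, SearchConfigInfoRequest) all exercise the same Writeable round trip: write the object to a BytesStreamOutput, then rebuild it from out.bytes().streamInput(). Below is a minimal sketch of that pattern in Java, assuming only the standard OpenSearch Writeable and Writeable.Reader contract; the RoundTripSupport helper name is hypothetical.

import java.io.IOException;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.Writeable;

public final class RoundTripSupport {
    private RoundTripSupport() {}

    // Serialize `original` and read it back through the type's stream reader,
    // mirroring the request.writeTo(out) / new JobRequest(out.bytes().streamInput())
    // sequence in the tests above.
    public static <T extends Writeable> T roundTrip(T original, Writeable.Reader<T> reader) throws IOException {
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            original.writeTo(out);
            try (StreamInput in = out.bytes().streamInput()) {
                return reader.read(in);
            }
        }
    }
}

Since JobRequest and JobResponse expose StreamInput constructors, an assertion such as Assert.assertEquals(request.getConfigID(), RoundTripSupport.roundTrip(request, JobRequest::new).getConfigID()) covers the same ground as the hand-rolled sequence.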
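The delete-detector tests above stub listener-style asynchronous methods (for example nodeStatemanager.getConfig and adTaskManager.deleteTasks) by pulling the callback out of Mockito's invocation arguments inside doAnswer. A small helper distilling that pattern is sketched below; AsyncStubs is a hypothetical name, and the callback position and Consumer shape are assumptions read off the call sites in this diff.

import java.util.function.Consumer;

import org.mockito.stubbing.Answer;

public final class AsyncStubs {
    private AsyncStubs() {}

    // Build an Answer that treats the stubbed call's argument at argIndex as a
    // Consumer-style callback and completes it immediately with the given value,
    // emulating the asynchronous operation having finished.
    public static <T> Answer<Void> callbackWith(int argIndex, T value) {
        return invocation -> {
            @SuppressWarnings("unchecked")
            Consumer<T> callback = (Consumer<T>) invocation.getArguments()[argIndex];
            callback.accept(value);
            return null;
        };
    }
}

With it, doAnswer(AsyncStubs.callbackWith(2, Optional.of(detector))).when(nodeStatemanager).getConfig(any(), any(), any()) behaves like the inline doAnswer blocks above.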
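The integration tests block on transport calls with client().execute(...).actionGet(timeout), while the unit tests hand doExecute a mocked PlainActionFuture. PlainActionFuture can bridge both styles because it is itself an ActionListener that can be waited on. The sketch below assumes the org.opensearch.timeseries.transport package for JobRequest and JobResponse (inferred from the renamed test locations) and that AnomalyDetectorJobAction is typed to JobResponse; the JobActionDriver name is hypothetical.

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.transport.AnomalyDetectorJobAction;
import org.opensearch.client.Client;
import org.opensearch.timeseries.transport.JobRequest;
import org.opensearch.timeseries.transport.JobResponse;

final class JobActionDriver {
    private JobActionDriver() {}

    // Execute the job action and wait synchronously, mirroring
    // client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000)
    // in the tests above, but going through an explicit listener.
    static JobResponse startJob(Client client, JobRequest request, long timeoutMillis) {
        PlainActionFuture<JobResponse> future = PlainActionFuture.newFuture();
        client.execute(AnomalyDetectorJobAction.INSTANCE, request, future);
        return future.actionGet(timeoutMillis);
    }
}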