/**
 *    Copyright (C) 2023-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */

#pragma once

#include "mongo/base/status.h"
#include "mongo/util/future_util.h"

namespace mongo {
namespace primary_only_service_helpers {

/**
 * Decides whether a non-OK Status should be retried: returns true to treat the error as transient
 * (retry) and false to treat it as unrecoverable.
 */
using RetryabilityPredicate = std::function<bool(const Status&)>;

/**
 * The predicate WithAutomaticRetry falls back to when the caller does not supply one.
 *
 * Declared 'inline' so that every translation unit including this header shares one object (and
 * one closure type). A plain namespace-scope 'const' variable would give each translation unit its
 * own internal-linkage copy.
 */
inline const auto kDefaultRetryabilityPredicate = [](const Status& status) {
    // Always attempt to retry on any type of retryable error. Also retry on errors
    // from stray killCursors and killOp commands being run. Cancellation and
    // NotPrimary errors may indicate the primary-only service Instance will be shut
    // down or is shutting down now. However, it is also possible that the error
    // originated from a remote response rather than being an error generated by
    // this shard itself. Defer whether or not to retry to the state of the
    // cancellation token. This means the body of the AsyncTry may continue to run
    // and error more times until the cancellation token is actually canceled if
    // this shard was in the midst of stepping down.
    return status.isA<ErrorCategory::RetriableError>() ||
        status == ErrorCodes::FailedToSatisfyReadPreference ||
        status.isA<ErrorCategory::CursorInvalidatedError>() || status == ErrorCodes::Interrupted ||
        status.isA<ErrorCategory::CancellationError>() ||
        status.isA<ErrorCategory::NotPrimaryError>() ||
        status.isA<ErrorCategory::NetworkTimeoutError>();
};

/**
 * Treats every error as retryable. When passed to WithAutomaticRetry, the onUnrecoverableError()
 * callback is therefore never invoked.
 */
inline const auto kAlwaysRetryPredicate = [](const Status&) {
    return true;
};

/**
 * A fluent-style API for executing asynchronous, future-returning try-until loops, specialized
 * around typical retry logic for components within primary-only services.
 *
 * onTransientError() and onUnrecoverableError() must each be called exactly once before until().
 * until() is templated on the status type its condition receives (Status or StatusWith<T>). That
 * type cannot be deduced from a lambda, so it must be spelled out explicitly, e.g. until<Status>().
 *
 * Example usage:
 *
 *      ExecutorFuture<void> future =
 *          primary_only_service_helpers::WithAutomaticRetry(
 *              [this, chainCtx] { chainCtx->moreToCome = doOneBatch(); })
 *              .onTransientError([](const Status& status) {
 *                  LOGV2(123,
 *                        "Transient error while doing batch",
 *                        "error"_attr = redact(status));
 *              })
 *              .onUnrecoverableError([](const Status& status) {
 *                  LOGV2_ERROR(456,
 *                              "Operation-fatal error while doing batch",
 *                              "error"_attr = redact(status));
 *              })
 *              .until<Status>([chainCtx](const Status& status) {
 *                  return status.isOK() && !chainCtx->moreToCome;
 *              })
 *              .on(std::move(executor), std::move(cancelToken));
 */
template <typename BodyCallable>
class [[nodiscard]] WithAutomaticRetry {
public:
    /**
     * 'body' is the callable handed to AsyncTry and run on each loop iteration. 'isRetryable'
     * decides whether a non-OK status is transient (keep looping) or unrecoverable (stop). It
     * defaults to kDefaultRetryabilityPredicate.
     */
    explicit WithAutomaticRetry(BodyCallable&& body,
                                RetryabilityPredicate isRetryable = kDefaultRetryabilityPredicate)
        : _isRetryable{std::move(isRetryable)}, _body{std::move(body)} {}

    /**
     * Registers the callback invoked with every non-OK status that 'isRetryable' accepts. May only
     * be called once.
     *
     * Returns an rvalue reference to this object, so it is meant to be chained within a single
     * full-expression rather than bound to a reference that outlives it.
     */
    decltype(auto) onTransientError(unique_function<void(const Status&)> onTransientError) && {
        invariant(!_onTransientError, "Cannot call onTransientError() twice");
        _onTransientError = std::move(onTransientError);
        return std::move(*this);
    }

    /**
     * Registers the callback invoked with the non-OK status that 'isRetryable' rejects; the loop
     * stops right after it runs. May only be called once.
     *
     * Returns an rvalue reference to this object, so it is meant to be chained within a single
     * full-expression rather than bound to a reference that outlives it.
     */
    decltype(auto) onUnrecoverableError(
        unique_function<void(const Status&)> onUnrecoverableError) && {
        invariant(!_onUnrecoverableError, "Cannot call onUnrecoverableError() twice");
        _onUnrecoverableError = std::move(onUnrecoverableError);
        return std::move(*this);
    }

    /**
     * Consumes this builder and returns the AsyncTry<BodyCallable>(...).until(...) object. The
     * caller still has to schedule it, e.g. via .on(executor, cancelToken).
     *
     * 'condition' is evaluated for OK statuses and also after a transient error. It must therefore
     * return false for a non-OK status if that error is meant to be retried. It is not consulted
     * for unrecoverable errors, which always end the loop.
     *
     * StatusType cannot be deduced from a lambda argument; call as until<Status>(...) or
     * until<StatusWith<T>>(...).
     */
    template <typename StatusType>
    auto until(unique_function<bool(const StatusType&)> condition) && {
        invariant(_onTransientError, "Must call onTransientError() first");
        invariant(_onUnrecoverableError, "Must call onUnrecoverableError() first");

        // *this is an expiring rvalue here, so every member, including the predicate, is moved
        // into the loop's closure rather than copied.
        return AsyncTry<BodyCallable>(std::move(_body))
            .until([onTransientError = std::move(_onTransientError),
                    onUnrecoverableError = std::move(_onUnrecoverableError),
                    condition = std::move(condition),
                    isRetryable = std::move(_isRetryable)](const StatusType& statusOrStatusWith) {
                // Refers into 'statusOrStatusWith', which outlives this call; no copy needed.
                const Status& status = _getStatus(statusOrStatusWith);

                if (!status.isOK()) {
                    if (isRetryable(status)) {
                        onTransientError(status);
                    } else {
                        onUnrecoverableError(status);
                        return true;
                    }
                }

                return condition(statusOrStatusWith);
            });
    }

private:
    // Extracts the Status from either a bare Status or a StatusWith<T>, so until() can inspect
    // both kinds of loop result uniformly.
    static const Status& _getStatus(const Status& status) {
        return status;
    }

    template <typename ValueType>
    static const Status& _getStatus(const StatusWith<ValueType>& statusWith) {
        return statusWith.getStatus();
    }

    RetryabilityPredicate _isRetryable;
    BodyCallable _body;

    unique_function<void(const Status&)> _onTransientError;
    unique_function<void(const Status&)> _onUnrecoverableError;
};

}  // namespace primary_only_service_helpers
}  // namespace mongo
