/*******************************************************************************
* Copyright (C) 2021 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneapi::mkl::lapack::getrf_batch and 
*       oneapi::mkl::lapack::getri_batch to perform batched LU factorization 
*       and matrix inverse on a SYCL device (CPU, GPU).
*
*       The supported floating point data types for matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*******************************************************************************/

#include <iostream>
#include <complex>
#include <vector>
#include "mkl.h"
#include "oneapi/mkl/lapack.hpp"
#include "common_for_examples.hpp"

//
// Example for batched LU factorization and matrix inversion, consisting of
// the initialization of a batch of square dense matrices, Ai.
// The LU factorization
// Ai = Pi * Li * Ui
// is computed for each matrix, Ai, and used to compute its inverse, inv(Ai),
// in the same array (in-place computation).
// Finally the results are post-processed.
//

// Runs one batched LU factorization (getrf_batch) followed by an in-place
// batched inverse (getri_batch) on `device`, using element type `data_t`.
//
// All buffers are USM shared allocations. The outcome is reported on stdout as
// "getri_batch ran OK" or "getri_batch FAILED"; nothing is returned. Failed
// allocations and exceptions caught here are turned into a FAILED report.
template <typename data_t>
void run_getri_batch_example(const sycl::device& device)
{
    // Input arguments
    const std::int64_t n           = 4;        // order of each square matrix
    const std::int64_t lda         = 6;        // leading dimension; larger than n, so columns are padded
    const std::int64_t stride_a    = lda * n;  // distance between consecutive matrices in A
    const std::int64_t stride_ipiv = n;        // distance between consecutive pivot arrays in ipiv
    const std::int64_t batch_size  = 3;

    // Variable holding status of calculations (0 means success)
    std::int64_t info = 0;

    // Asynchronous error handler
    auto error_handler = [&] (sycl::exception_list exceptions) {
        for (auto const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            } catch(oneapi::mkl::lapack::exception const& e) {
                // Handle LAPACK related exceptions happened during asynchronous call
                info = e.info();
                std::cout << "Unexpected exception caught during asynchronous LAPACK operation:\n" << e.what() << "\ninfo: " << e.info() << std::endl;
            } catch(sycl::exception const& e) {
                // Handle not LAPACK related exceptions happened during asynchronous call
                std::cout << "Unexpected exception caught during asynchronous operation:\n" << e.what() << std::endl;
                info = -1;
            }
        }
    };

    // Create execution queue for selected device
    sycl::queue queue(device, error_handler);
    sycl::context context = queue.get_context();

    // Frees a USM allocation, skipping pointers that were never (successfully) allocated
    auto release = [&] (void* ptr) {
        if (ptr != nullptr) {
            sycl::free(ptr, context);
        }
    };

    // Allocate shared memory for matrices
    const std::int64_t size_a    = stride_a * batch_size;
    const std::int64_t size_ipiv = stride_ipiv * batch_size;
    data_t *A = static_cast<data_t*>(sycl::malloc_shared(size_a * sizeof(data_t), device, context));
    std::int64_t *ipiv = static_cast<std::int64_t*>(sycl::malloc_shared(size_ipiv * sizeof(std::int64_t), device, context));
    data_t *getrf_batch_scratchpad = nullptr;
    data_t *getri_batch_scratchpad = nullptr;

    if (A == nullptr || ipiv == nullptr) {
        // Without this check the initialization loop below would write through a null pointer
        std::cout << "Failed to allocate shared memory for matrices or pivot indices" << std::endl;
        info = -1;
    } else {
        // Initialize batch of matrices to have random values.
        // Only the n x n part of each matrix is written; the padding rows up to lda are left untouched.
        for (std::int64_t imat = 0; imat < batch_size; imat++) {
            const std::int64_t offset = imat * stride_a;
            for (std::int64_t i = 0; i < n; i++)
                for (std::int64_t j = 0; j < n; j++)
                    A[offset + j + i*lda] = rand_scalar<data_t>();
        }

        try {
            // Get sizes of scratchpads for calculations
            const std::int64_t getrf_batch_scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size<data_t>(queue, n, n, lda, stride_a, stride_ipiv, batch_size);
            const std::int64_t getri_batch_scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size<data_t>(queue, n, lda, stride_a, stride_ipiv, batch_size);

            getrf_batch_scratchpad = static_cast<data_t*>(sycl::malloc_shared(getrf_batch_scratchpad_size * sizeof(data_t), device, context));
            getri_batch_scratchpad = static_cast<data_t*>(sycl::malloc_shared(getri_batch_scratchpad_size * sizeof(data_t), device, context));

            // A null scratchpad is only an error when a non-empty one was requested
            const bool scratchpad_alloc_failed =
                (getrf_batch_scratchpad_size > 0 && getrf_batch_scratchpad == nullptr) ||
                (getri_batch_scratchpad_size > 0 && getri_batch_scratchpad == nullptr);

            if (scratchpad_alloc_failed) {
                std::cout << "Failed to allocate shared memory for scratchpads" << std::endl;
                info = -1;
            } else {
                // Submit batched LU factorization on device queue
                auto getrf_batch_done_event = oneapi::mkl::lapack::getrf_batch(queue, n, n, A, lda, stride_a, ipiv, stride_ipiv, batch_size, getrf_batch_scratchpad, getrf_batch_scratchpad_size);

                // Submit batched matrix inverse on device queue. Its execution is dependent on the completion of the batched LU factorization (waits on getrf_batch_done_event)
                auto getri_batch_done_event = oneapi::mkl::lapack::getri_batch(queue, n, A, lda, stride_a, ipiv, stride_ipiv, batch_size, getri_batch_scratchpad, getri_batch_scratchpad_size, { getrf_batch_done_event });
            }
        } catch(oneapi::mkl::lapack::exception const& e) {
            // Handle LAPACK related exceptions happened during synchronous call
            std::cout << "Unexpected exception caught during synchronous call to LAPACK API:\nreason: " << e.what() << "\ninfo: " << e.info() << std::endl;
            info = e.info();
        } catch(sycl::exception const& e) {
            // Handle not LAPACK related exceptions happened during synchronous call
            std::cout << "Unexpected exception caught during synchronous call to SYCL API:\n" << e.what() << std::endl;
            info = -1;
        }
    }

    // Wait for all submitted jobs to finish; asynchronous errors are routed to error_handler
    queue.wait_and_throw();

    // Buffers are released only after the queue has drained, so no submitted job can still use them
    release(getri_batch_scratchpad);
    release(getrf_batch_scratchpad);
    release(ipiv);
    release(A);

    std::cout << "getri_batch " << ((info == 0) ? "ran OK" : "FAILED") << std::endl;
    return;
}


//
// Description of example setup, APIs used and supported floating point type precisions
//

void print_example_banner()
{
    // Banner text, one array entry per output line. The leading and trailing
    // empty entries produce the blank lines around the framed description.
    static const char* const banner_lines[] = {
        "",
        "##################################################################################",
        "# Batched LU factorization and matrix inverse using Unified Shared Memory Example:",
        "# Computes the LU factorization for a batch of dense square matrices as ",
        "#                 Ai = Pi * Li * Ui ",
        "# where each matrix, Ai, is assumed to be stored in a contiguous block of ",
        "# memory, A, at a constant stride from each other.",
        "# The LU factorization of each matrix is then used to compute its ",
        "# inverse. Each matrix, Ai, is overwritten with its inverse inv(Ai) ",
        "# in the same contiguous block of memory (in-place computation) ",
        "# ",
        "# Using APIs:",
        "#   getrf_batch and getri_batch",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        "##################################################################################",
        "",
    };

    // Every line ends with std::endl, so the stream is flushed after each one.
    for (const char* text : banner_lines) {
        std::cout << text << std::endl;
    }
}


//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
//      -DSYCL_DEVICES_cpu -- only runs SYCL CPU device.
//      -DSYCL_DEVICES_gpu -- only runs SYCL GPU device.
//      -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices.
//
//  For each device selected, the matrix inversion example runs with all
//  data types supported on that device.
//
int main(int argc, char **argv)
{
    print_example_banner();

    // Find list of devices
    std::list<my_sycl_device_types> listOfDevices;
    set_list_of_devices(listOfDevices);

    for(auto &deviceType : listOfDevices) {
        sycl::device myDev;
        bool myDevIsFound = false;
        get_sycl_device(myDev, myDevIsFound, deviceType);

        if(myDevIsFound) {
            std::cout << std::endl << "Running getri_batch examples on " << sycl_device_names[deviceType] << "." << std::endl;

            std::cout << "\tRunning with single precision real data type:" << std::endl;
            run_getri_batch_example<float>(myDev);

            std::cout << "\tRunning with single precision complex data type:" << std::endl;
            run_getri_batch_example<std::complex<float>>(myDev);

            if (isDoubleSupported(myDev)) {
                std::cout << "\tRunning with double precision real data type:" << std::endl;
                run_getri_batch_example<double>(myDev);

                std::cout << "\tRunning with double precision complex data type:" << std::endl;
                run_getri_batch_example<std::complex<double>>(myDev);
            } else {
                std::cout << "\tDouble precision not supported on this device " << std::endl;
                std::cout << std::endl;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; Fail on missing devices is enabled.\n";
            return 1;
#else
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; skipping " << sycl_device_names[deviceType] << " tests.\n";
#endif
        }
    }
    mkl_free_buffers();
    return 0;
}
