001package com.identityworksllc.iiq.common.task;
002
003import com.identityworksllc.iiq.common.logging.SLogger;
004import com.identityworksllc.iiq.common.threads.SailPointWorker;
005import org.apache.commons.logging.Log;
006import sailpoint.api.SailPointContext;
007import sailpoint.tools.GeneralException;
008
009import java.util.HashMap;
010import java.util.Map;
011import java.util.Objects;
012
013/**
014 * The worker for handling each input object. The input type is always a list
015 * of the given objects, to support batching, but for non-batched situations,
016 * the list may be of length 1.
017 *
018 * None of the inputs may be null.
019 */
020public class ThreadExecutorWorker<T> extends SailPointWorker {
021    /**
022     * The ultimate consumer of each item. In the default setup, this will
023     * invoke {@link AbstractThreadedTask#threadExecute(SailPointContext, Map, Object)},
024     * but consumers can do whatever they want.
025     */
026    private final PrivateContextObjectConsumer<T> consumer;
027
028    /**
029     * The objects to iterate in this thread
030     */
031    private final Iterable<T> objects;
032
033    /**
034     * The task context, used to check status, increment counters, etc
035     */
036    private final ThreadedTaskListener<T> taskListener;
037
038    /**
039     * Basic constructor, corresponds to {@link ThreadWorkerCreator}.
040     *
041     * None of the inputs may be null.
042     *
043     * @param objects The objects to iterate over
044     * @param consumer The consumer of those objects (i.e., who is implementing threadExecute)
045     * @param taskContext The task context
046     */
047    public ThreadExecutorWorker(Iterable<T> objects, PrivateContextObjectConsumer<T> consumer, ThreadedTaskListener<T> taskContext) {
048        this.objects = Objects.requireNonNull(objects);
049        this.consumer = Objects.requireNonNull(consumer);
050        this.taskListener = Objects.requireNonNull(taskContext);
051    }
052
053    /**
054     * Invokes {@link PrivateContextObjectConsumer#threadExecute(SailPointContext, Map, Object)} for
055     * each object in the list. Also invokes a variety of callbacks via the taskContext.
056     *
057     * @param threadContext The thread context
058     * @param logger        The log attached to this Worker
059     * @return always null
060     * @throws InterruptedException if the thread has been interrupted
061     */
062    @Override
063    public Object execute(SailPointContext threadContext, Log logger) throws InterruptedException {
064        SLogger slogger = new SLogger(logger);
065        boolean skip = false;
066        if (!taskListener.isTerminated()) {
067            try {
068                taskListener.beforeBatch(threadContext);
069            } catch(GeneralException e) {
070                skip = true;
071                logger.error("Caught an error invoking beforeBatch; skipping batch", e);
072            }
073            if (!skip) {
074                for (T in : objects) {
075                    taskListener.beforeExecution(Thread.currentThread(), in);
076
077                    if (Thread.interrupted() || taskListener.isTerminated()) {
078                        throw new InterruptedException("Thread interrupted");
079                    }
080                    Map<String, Object> args = new HashMap<>();
081                    args.put("context", threadContext);
082                    args.put("log", slogger);
083                    args.put("logger", slogger);
084                    args.put("object", in);
085                    args.put("worker", this);
086                    args.put("taskListener", this.taskListener);
087                    args.put("monitor", this.monitor);
088                    try {
089                        consumer.threadExecute(threadContext, args, in);
090                        taskListener.handleSuccess(in);
091                        threadContext.commitTransaction();
092                    } catch (Exception e) {
093                        taskListener.handleException(e);
094                        taskListener.handleFailure(in);
095                    }
096                }
097            }
098        }
099
100        if (!skip) {
101            try {
102                taskListener.afterBatch(threadContext);
103            } catch (GeneralException e) {
104                logger.error("Caught an error invoking afterBatch", e);
105            }
106        }
107        return null;
108    }
109}