001package com.identityworksllc.iiq.common.task;
002
003import com.identityworksllc.iiq.common.CommonConstants;
004import com.identityworksllc.iiq.common.Functions;
005import com.identityworksllc.iiq.common.iterators.BatchingIterator;
006import com.identityworksllc.iiq.common.iterators.TransformingIterator;
007import com.identityworksllc.iiq.common.threads.SailPointWorker;
008import org.apache.commons.logging.Log;
009import org.apache.commons.logging.LogFactory;
010import sailpoint.api.SailPointContext;
011import sailpoint.api.SailPointFactory;
012import sailpoint.object.Attributes;
013import sailpoint.object.TaskResult;
014import sailpoint.object.TaskSchedule;
015import sailpoint.task.AbstractTaskExecutor;
016import sailpoint.task.TaskMonitor;
017import sailpoint.tools.GeneralException;
018import sailpoint.tools.Util;
019
020import java.util.ArrayList;
021import java.util.Collections;
022import java.util.Iterator;
023import java.util.List;
024import java.util.Map;
025import java.util.concurrent.ExecutorService;
026import java.util.concurrent.Executors;
027import java.util.concurrent.TimeUnit;
028import java.util.concurrent.atomic.AtomicBoolean;
029import java.util.concurrent.atomic.AtomicInteger;
030import java.util.function.BiConsumer;
031import java.util.function.Consumer;
032
033/**
034 * An abstract superclass for nearly all custom multi-threaded SailPoint tasks. This task
035 * will retrieve a list of objects and then pass each of them to a processor method
036 * in the subclass in parallel.
037 *
038 * The overall flow is:
039 *
040 *  1. Invoke {@link #parseArgs(Attributes)} to extract task arguments. Subclasses should
041 *     override this to retrieve their own parameters.
042 *  2. Invoke {@link #getObjectIterator(SailPointContext, Attributes)} to retrieve a list of
043 *     items to iterate over. This iterator can be "streaming" or not.
044 *  3. Invoke {@link #submitAndWait(SailPointContext, TaskResult, Iterator)} to begin processing
045 *     of the list of items. This can be replaced by a subclass, but the default flow is described
046 *     below. It is unlikely that you will need to change it.
047 *  4. Clean up the thread pool and update the TaskResult with outcomes.
048 *
049 * Submit-and-wait proceeds as follows:
050 *
051 *  1. Retrieve the batch size from {@link #getBatchSize()}
052 *  2. Create a thread pool with the specified number of threads in it.
053 *  3. For each item, invoke the subclass's {@link #threadExecute(SailPointContext, Map, T)} method,
054 *     passing the current item in the third parameter. If a batch size is set, more than one
055 *     item will be passed in a single worker thread, eliminating the need to build and destroy
056 *     lots of private contexts. This will likely be more efficient for large operations.
057 *
058 * Via the {@link SailPointWorker} class, the {@link #threadExecute(SailPointContext, Map, T)} method
059 * will also receive an appropriately thread-specific SailPointContext object that can be used without
060 * worrying about transaction length or overlap.
061 *
062 * Subclasses can override most parts of this process by extending the various protected methods.
063 *
064 * Subclasses cannot direct receive a termination notification, but can register any number of
065 * termination handlers by invoking {@link #addTerminationHandler(Functions.GenericCallback)}.
066 * This class makes a best effort to invoke all termination handlers.
067 *
068 * @param <T> The type of input object that will be passed to threadExecute.
069 *
070 * @author Devin Rosenbauer
071 * @author Instrumental Identity
072 */
073public abstract class AbstractThreadedTask<T> extends AbstractTaskExecutor implements PrivateContextObjectConsumer<T> {
074
075    /**
076     * The default threaded task listener
077     */
078    private class DefaultThreadedTaskListener extends ThreadedTaskListener<T> {
079        /**
080         * The task result to update with output
081         */
082        private final TaskResult taskResult;
083
084        public DefaultThreadedTaskListener(TaskResult taskResult) {
085            this.taskResult = taskResult;
086        }
087
088        @Override
089        public void afterBatch(SailPointContext threadContext) throws GeneralException {
090            AbstractThreadedTask.this.beforeBatch(threadContext);
091        }
092
093        @Override
094        public void beforeBatch(SailPointContext taskContext) throws GeneralException {
095            AbstractThreadedTask.this.beforeBatch(taskContext);
096        }
097
098        @Override
099        public void beforeExecution(Thread theThread, T input) {
100            if (beforeExecutionHook != null) {
101                beforeExecutionHook.accept(theThread, input);
102            }
103        }
104
105        @Override
106        public void handleException(Exception e) {
107            taskResult.addException(e);
108        }
109
110        @Override
111        public void handleFailure(T input) {
112            failureMarker.accept(input);
113        }
114
115        @Override
116        public void handleSuccess(T input) {
117            successMarker.accept(input);
118        }
119
120        @Override
121        public boolean isTerminated() {
122            return AbstractThreadedTask.this.terminated.get();
123        }
124    }
125
126    /**
127     * The batch size, which may be zero for no batching
128     */
129    private int batchSize;
130    /**
131     * If present, this BiConsumer can be invoked before execution of each object.
132     * The subclass is responsible for making this call. This is mainly useful as
133     * a testing hook.
134     */
135    protected BiConsumer<Thread, Object> beforeExecutionHook;
136    /**
137     * The parent SailPoint context
138     */
139    protected SailPointContext context;
140    /**
141     * The thread pool
142     */
143    protected ExecutorService executor;
144    /**
145     * The counter of how many threads have indicated failures
146     */
147    protected AtomicInteger failureCounter;
148    /**
149     * The callback on failed execution for each item
150     */
151    private Consumer<T> failureMarker;
152    /**
153     * The log object
154     */
155    protected Log log;
156    /**
157     * The counter of how many threads have indicated success
158     */
159    protected AtomicInteger successCounter;
160    /**
161     * The callback on successful execution for each item
162     */
163    private Consumer<T> successMarker;
164    /**
165     * The TaskResult to keep updated with changes
166     */
167    protected TaskResult taskResult;
168
169    /**
170     * The TaskSchedule, which can be used in querying
171     */
172    protected TaskSchedule taskSchedule;
173    /**
174     * The boolean flag indicating that this task has been terminated
175     */
176    protected final AtomicBoolean terminated;
177    /**
178     * A set of callbacks to run on task termination
179     */
180    private final List<Functions.GenericCallback> terminationHandlers;
181    /**
182     * How many threads are to be created
183     */
184    protected int threadCount;
185
186    /**
187     * A way to override creation of the thread workers
188     */
189    private ThreadWorkerCreator<T> workerCreator;
190
191    public AbstractThreadedTask() {
192        this.terminated = new AtomicBoolean(false);
193        this.successCounter = new AtomicInteger(0);
194        this.failureCounter = new AtomicInteger(0);
195        this.terminationHandlers = new ArrayList<>();
196    }
197
198    /**
199     * Adds a termination handler to this execution of the task
200     * @param handler The termination handler to run on completion
201     */
202    protected final void addTerminationHandler(Functions.GenericCallback handler) {
203        this.terminationHandlers.add(handler);
204    }
205
206    /**
207     * Invoked by the default worker thread after each batch is completed.
208     * This can be overridden by a subclass to do arbitrary cleanup.
209     *
210     * @param context The context for this thread
211     * @throws GeneralException if anything goes wrong
212     */
213    public void afterBatch(SailPointContext context) throws GeneralException {
214        /* No-op by default */
215    }
216
217    /**
218     * Invoked after completion of all threads, even if they fail
219     * @param context The SailPoint context
220     */
221    protected void afterCompletion(SailPointContext context) {
222        /* No-op by default */
223    }
224
225    /**
226     * Invoked by the default worker thread before each batch is begun. If this
227     * method throws an exception, the batch worker ought to prevent the batch
228     * from being executed.
229     *
230     * @param context The context for this thread
231     * @throws GeneralException if any failures occur
232     */
233    public void beforeBatch(SailPointContext context) throws GeneralException {
234        /* No-op by default */
235    }
236
237    /**
238     * Retrieves an iterator over batches of items, with the size suggested by the second
239     * parameter. If left unmodified, returns either a {@link BatchingIterator} when the
240     * batch size is greater than 1, or a {@link TransformingIterator} that constructs a
241     * singleton list for each item when batch size is 1.
242     *
243     * If possible, the returned Iterator should simply wrap the input, rather than
244     * consuming it. This allows for "live" iterators that read from a data source
245     * directly rather than pre-reading. However, beware Hibernate iterators here
246     * because a 'commit' can kill those mid-iterate.
247     *
248     * @param items The input iterator of items
249     * @param batchSize The batch size
250     * @return The iterator over a list of items
251     */
252    protected Iterator<List<T>> createBatchIterator(Iterator<? extends T> items, int batchSize) {
253        Iterator<List<T>> batchingIterator;
254        if (batchSize > 1) {
255            // Batching iterator will combine items into lists of up to batchSize
256            batchingIterator = new BatchingIterator<>(items, batchSize);
257        } else {
258            // This iterator will just transform each item into a list containing only that item
259            batchingIterator = new TransformingIterator<T, List<T>>(items, Collections::singletonList);
260        }
261        return batchingIterator;
262    }
263
264    /**
265     * The main method of this task executor, which invokes the appropriate hook methods.
266     */
267    @Override
268    public final void execute(SailPointContext ctx, TaskSchedule ts, TaskResult tr, Attributes<String, Object> args) throws Exception {
269        TaskMonitor monitor = new TaskMonitor(ctx, tr);
270
271        this.terminated.set(false);
272        this.successCounter.set(0);
273        this.failureCounter.set(0);
274        this.terminationHandlers.clear();
275
276        this.workerCreator = ThreadExecutorWorker::new;
277        this.failureMarker = this::markFailure;
278        this.successMarker = this::markSuccess;
279
280        this.log = LogFactory.getLog(this.getClass());
281        this.context = ctx;
282        this.taskResult = tr;
283        this.taskSchedule = ts;
284
285        monitor.updateProgress("Parsing input arguments");
286        monitor.commitMasterResult();
287
288        parseArgs(args);
289
290        monitor.updateProgress("Retrieving target objects");
291        monitor.commitMasterResult();
292
293        Iterator<? extends T> items = getObjectIterator(ctx, args);
294        if (items != null) {
295            monitor.updateProgress("Processing target objects");
296            monitor.commitMasterResult();
297
298            try {
299                submitAndWait(ctx, taskResult, items);
300            } finally {
301                afterCompletion(ctx);
302            }
303
304            if (!terminated.get()) {
305                monitor.updateProgress("Invoking termination handlers");
306                monitor.commitMasterResult();
307
308                runTerminationHandlers();
309            }
310        }
311    }
312
313    /**
314     * Gets the batch size for this task. By default, this is the batch size passed
315     * as an input to the task, but this may be overridden by subclasses.
316     *
317     * @return The batch size for each thread
318     */
319    public int getBatchSize() {
320        return batchSize;
321    }
322
323    /**
324     * Gets the running executor for this task
325     * @return The executor
326     */
327    public final ExecutorService getExecutor() {
328        return this.executor;
329    }
330
331    /**
332     * Retrieves an Iterator that will produce the stream of objects to be processed
333     * in parallel. Each object produced by this Iterator will be passed in its turn
334     * to {@link #threadExecute(SailPointContext, Map, Object)} as the third parameter.
335     *
336     * IMPORTANT NOTES RE: HIBERNATE:
337     *
338     * It may be unwise to return a "live" Hibernate iterator of the sort provided by
339     * context.search here. The next read of the iterator will fail with a "Result Set
340     * Closed" exception if anything commits this context while the iterator is still
341     * being consumed. It is likely that the first worker threads will execute before
342     * the iterator is fully read.
343     *
344     * If you return a SailPointObject or any other object dependent on a Hibernate
345     * context, you will likely receive context-related errors in your worker thread
346     * unless you make an effort to re-attach the object to the thread context.
347     *
348     * TODO One option may be to pass in a private context here, but it couldn't be closed until after iteration is complete.
349     *
350     * @param context The top-level task Sailpoint context
351     * @param args The task arguments
352     * @return An iterator containing the objects to be iterated over
353     * @throws GeneralException if any failures occur
354     */
355    protected abstract Iterator<? extends T> getObjectIterator(SailPointContext context, Attributes<String, Object> args) throws GeneralException;
356
357    /**
358     * Marks this item as a failure by incrementing the failure counter. Subclasses
359     * may override this method to add additional behavior.
360     */
361    protected void markFailure(T item) {
362        failureCounter.incrementAndGet();
363    }
364
365    /**
366     * Marks this item as a success by incrementing the success counter. Subclasses
367     * may override this method to add additional behavior.
368     */
369    protected void markSuccess(T item) {
370        successCounter.incrementAndGet();
371    }
372
373    /**
374     * Extracts the thread count from the task arguments. Subclasses should override
375     * this method to extract their own arguments. You must either call super.parseArgs()
376     * in any subclass implementation of this method or set {@link #threadCount} yourself.
377     *
378     * @param args The task arguments
379     * @throws Exception if any failures occur parsing arguments
380     */
381    protected void parseArgs(Attributes<String, Object> args) throws Exception {
382        this.threadCount = Util.atoi(args.getString(CommonConstants.THREADS_ATTR));
383        if (this.threadCount < 1) {
384            this.threadCount = Util.atoi(args.getString("threadCount"));
385        }
386        if (this.threadCount < 1) {
387            this.threadCount = 1;
388        }
389
390        this.batchSize = args.getInt("batchSize", 0);
391    }
392
393    /**
394     * Prepares the thread pool executor. The default implementation simply constructs
395     * a fixed-size executor service, but subclasses may override this behavior with
396     * their own implementations.
397     *
398     * After this method is finished, the {@link #executor} attribute should be set
399     * to an {@link ExecutorService} that can accept new inputs.
400     *
401     * @throws GeneralException if any failures occur
402     */
403    protected void prepareExecutor() throws GeneralException {
404        executor = Executors.newFixedThreadPool(threadCount);
405    }
406
407    /**
408     * Runs the cleanup handler
409     */
410    private void runTerminationHandlers() {
411        if (terminationHandlers.size() > 0) {
412            try {
413                SailPointContext context = SailPointFactory.getCurrentContext();
414                for (Functions.GenericCallback handler : terminationHandlers) {
415                    try {
416                        handler.run(context);
417                    } catch(Error e) {
418                        throw e;
419                    } catch(Throwable e) {
420                        log.error("Caught an error while running termination handlers", e);
421                    }
422                }
423            } catch (GeneralException e) {
424                log.error("Caught an error while running termination handlers", e);
425            }
426        }
427    }
428
429    /**
430     * Sets the "before execution hook", an optional pluggable callback that will
431     * be invoked prior to the execution of each thread. This BiConsumer's accept()
432     * method must be thread-safe as it will be invoked in parallel.
433     *
434     * @param beforeExecutionHook An optional BiConsumer callback hook
435     */
436    public final void setBeforeExecutionHook(BiConsumer<Thread, Object> beforeExecutionHook) {
437        this.beforeExecutionHook = beforeExecutionHook;
438    }
439
440    /**
441     * Sets the failure marking callback
442     * @param failureMarker The callback invoked on item failure
443     */
444    public final void setFailureMarker(Consumer<T> failureMarker) {
445        this.failureMarker = failureMarker;
446    }
447
448    /**
449     * Sets the success marking callback
450     * @param successMarker The callback invoked on item failure
451     */
452    public final void setSuccessMarker(Consumer<T> successMarker) {
453        this.successMarker = successMarker;
454    }
455
456    /**
457     * Sets the worker creator function. This function should return a SailPointWorker
458     * extension that will take the given List of typed items and process them when
459     * its thread is invoked.
460     *
461     * @param workerCreator The worker creator function
462     */
463    public final void setWorkerCreator(ThreadWorkerCreator<T> workerCreator) {
464        this.workerCreator = workerCreator;
465    }
466
467    /**
468     * Submits the iterator of items to the thread pool, calling threadExecute for each
469     * one, then waits for all of the threads to complete or the task to be terminated.
470     *
471     * @param context The SailPoint context
472     * @param taskResult The taskResult to update (for monitoring)
473     * @param items The iterator over items being processed
474     * @throws GeneralException if any failures occur
475     */
476    protected void submitAndWait(SailPointContext context, TaskResult taskResult, Iterator<? extends T> items) throws GeneralException {
477        int batchSize = getBatchSize();
478        final TaskMonitor monitor = new TaskMonitor(context, taskResult);
479
480        // Default listener allowing individual worker state to be propagated up
481        // through the various callbacks, hooks, and listeners on this task.
482        ThreadedTaskListener<T> taskContext = new DefaultThreadedTaskListener(taskResult);
483        try {
484            prepareExecutor();
485            AtomicInteger totalCount = new AtomicInteger();
486            try {
487                Iterator<List<T>> batchingIterator = createBatchIterator(items, batchSize);
488
489                batchingIterator.forEachRemaining(listOfObjects -> {
490                    SailPointWorker worker = workerCreator.createWorker(listOfObjects, this, taskContext);
491                    worker.setMonitor(monitor);
492                    executor.submit(worker.runnable());
493                    totalCount.incrementAndGet();
494                });
495            } finally {
496                Util.flushIterator(items);
497            }
498
499            try {
500                monitor.updateProgress("Submitted " + totalCount.get() + " tasks");
501                monitor.commitMasterResult();
502            } catch(GeneralException e) {
503                /* Ignore this */
504            }
505
506            // No further items can be submitted to the executor at this point
507            executor.shutdown();
508
509            log.info("Waiting for all threads in task " + taskResult.getName() + " to terminate");
510
511            int totalItems = totalCount.get();
512            if (batchSize > 1) {
513                totalItems = totalItems * batchSize;
514            }
515
516            while(!executor.isTerminated()) {
517                executor.awaitTermination(5, TimeUnit.SECONDS);
518                int finished = this.successCounter.get() + this.failureCounter.get();
519                try {
520                    monitor.updateProgress("Completed " + finished + " of " + totalItems + " items");
521                    monitor.commitMasterResult();
522                } catch(GeneralException e) {
523                    /* Ignore this */
524                }
525            }
526
527            log.info("All threads have terminated in task " + taskResult.getName());
528
529            taskResult.setAttribute("successes", successCounter.get());
530            taskResult.setAttribute("failures", failureCounter.get());
531            monitor.commitMasterResult();
532        } catch(InterruptedException e) {
533            terminate();
534            throw new GeneralException(e);
535        }
536    }
537
538    /**
539     * Terminates the task by setting the terminated flag, interrupting the executor, waiting five seconds for it to exit, then invoking any shutdown hooks
540     *
541     * @return Always true
542     */
543    @Override
544    public final boolean terminate() {
545        if (!terminated.get()) {
546            synchronized(terminated) {
547                if (!terminated.get()) {
548                    terminated.set(true);
549                    if (executor != null && !executor.isTerminated()) {
550                        executor.shutdownNow();
551                        if (terminationHandlers.size() > 0) {
552                            try {
553                                executor.awaitTermination(2L, TimeUnit.SECONDS);
554                            } catch (InterruptedException e) {
555                                log.debug("Interrupted while waiting during termination", e);
556                            }
557                        }
558                    }
559                    runTerminationHandlers();
560                }
561            }
562        }
563        return true;
564    }
565
566    /**
567     * This method will be called in parallel for each item produced by {@link #getObjectIterator(SailPointContext, Attributes)}.
568     *
569     * DO NOT use the parent context in this method. You will encounter Weird Database Issues.
570     *
571     * @param threadContext A private IIQ context for the current JVM thread
572     * @param parameters A set of default parameters suitable for a Rule or Script. In the default implementation, the object will be in this map as 'object'.
573     * @param obj The object to process
574     * @return An arbitrary value (ignored by default)
575     * @throws GeneralException if any failures occur
576     */
577    public abstract Object threadExecute(SailPointContext threadContext, Map<String, Object> parameters, T obj) throws GeneralException;
578}