001package com.identityworksllc.iiq.common.task;
002
import com.identityworksllc.iiq.common.CommonConstants;
import com.identityworksllc.iiq.common.Functions;
import com.identityworksllc.iiq.common.iterators.BatchingIterator;
import com.identityworksllc.iiq.common.iterators.TransformingIterator;
import com.identityworksllc.iiq.common.threads.SailPointWorker;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import sailpoint.api.SailPointContext;
import sailpoint.api.SailPointFactory;
import sailpoint.object.Attributes;
import sailpoint.object.TaskResult;
import sailpoint.object.TaskSchedule;
import sailpoint.task.AbstractTaskExecutor;
import sailpoint.task.TaskMonitor;
import sailpoint.tools.GeneralException;
import sailpoint.tools.Util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
032
033/**
034 * An abstract superclass for nearly all custom multi-threaded SailPoint tasks. This task
035 * will retrieve a list of objects and then pass each of them to a processor method
036 * in the subclass in parallel.
037 *
038 * The overall flow is:
039 *
040 *  1. Invoke {@link #parseArgs(Attributes)} to extract task arguments. Subclasses should
041 *     override this to retrieve their own parameters.
042 *  2. Invoke {@link #getObjectIterator(SailPointContext, Attributes)} to retrieve a list of
043 *     items to iterate over. This iterator can be "streaming" or not.
044 *  3. Invoke {@link #submitAndWait(SailPointContext, TaskResult, Iterator)} to begin processing
045 *     of the list of items. This can be replaced by a subclass, but the default flow is described
046 *     below. It is unlikely that you will need to change it.
047 *  4. Clean up the thread pool and update the TaskResult with outcomes.
048 *
049 * Submit-and-wait proceeds as follows:
050 *
051 *  1. Retrieve the batch size from {@link #getBatchSize()}
052 *  2. Create a thread pool with the specified number of threads in it.
053 *  3. For each item, invoke the subclass's {@link #threadExecute(SailPointContext, Map, T)} method,
054 *     passing the current item in the third parameter. If a batch size is set, more than one
055 *     item will be passed in a single worker thread, eliminating the need to build and destroy
056 *     lots of private contexts. This will likely be more efficient for large operations.
057 *
058 * Via the {@link SailPointWorker} class, the {@link #threadExecute(SailPointContext, Map, T)} method
059 * will also receive an appropriately thread-specific SailPointContext object that can be used without
060 * worrying about transaction length or overlap.
061 *
062 * Subclasses can override most parts of this process by extending the various protected methods.
063 *
064 * Subclasses cannot direct receive a termination notification, but can register any number of
065 * termination handlers by invoking {@link #addTerminationHandler(Functions.GenericCallback)}.
066 * This class makes a best effort to invoke all termination handlers.
067 *
068 * @param <T> The type of input object that will be passed to threadExecute.
069 *
070 * @author Devin Rosenbauer
071 * @author Instrumental Identity
072 */
073public abstract class AbstractThreadedTask<T> extends AbstractTaskExecutor implements PrivateContextObjectConsumer<T> {
074
075    /**
076     * The default threaded task listener
077     */
078    private class DefaultThreadedTaskListener extends ThreadedTaskListener<T> {
079        /**
080         * The task result to update with output
081         */
082        private final TaskResult taskResult;
083
084        public DefaultThreadedTaskListener(TaskResult taskResult) {
085            this.taskResult = taskResult;
086        }
087
088        @Override
089        public void afterBatch(SailPointContext threadContext) throws GeneralException {
090            AbstractThreadedTask.this.beforeBatch(threadContext);
091        }
092
093        @Override
094        public void beforeBatch(SailPointContext taskContext) throws GeneralException {
095            AbstractThreadedTask.this.beforeBatch(taskContext);
096        }
097
098        @Override
099        public void beforeExecution(Thread theThread, T input) {
100            if (beforeExecutionHook != null) {
101                beforeExecutionHook.accept(theThread, input);
102            }
103        }
104
105        @Override
106        public void handleException(Exception e) {
107            taskResult.addException(e);
108        }
109
110        @Override
111        public void handleFailure(T input) {
112            failureMarker.accept(input);
113        }
114
115        @Override
116        public void handleSuccess(T input) {
117            successMarker.accept(input);
118        }
119
120        @Override
121        public boolean isTerminated() {
122            return AbstractThreadedTask.this.terminated.get();
123        }
124    }
125
126    /**
127     * The batch size, which may be zero for no batching
128     */
129    private int batchSize;
130    /**
131     * If present, this BiConsumer can be invoked before execution of each object.
132     * The subclass is responsible for making this call. This is mainly useful as
133     * a testing hook.
134     */
135    protected BiConsumer<Thread, Object> beforeExecutionHook;
136    /**
137     * The parent SailPoint context
138     */
139    protected SailPointContext context;
140    /**
141     * The thread pool
142     */
143    protected ExecutorService executor;
144    /**
145     * The counter of how many threads have indicated failures
146     */
147    protected AtomicInteger failureCounter;
148    /**
149     * The callback on failed execution for each item
150     */
151    private Consumer<T> failureMarker;
152    /**
153     * The log object
154     */
155    protected Log log;
156    /**
157     * The counter of how many threads have indicated success
158     */
159    protected AtomicInteger successCounter;
160    /**
161     * The callback on successful execution for each item
162     */
163    private Consumer<T> successMarker;
164    /**
165     * The TaskResult to keep updated with changes
166     */
167    protected TaskResult taskResult;
168
169    /**
170     * The TaskSchedule, which can be used in querying
171     */
172    protected TaskSchedule taskSchedule;
173    /**
174     * The boolean flag indicating that this task has been terminated
175     */
176    protected final AtomicBoolean terminated;
177    /**
178     * A set of callbacks to run on task termination
179     */
180    private final List<Functions.GenericCallback> terminationHandlers;
181    /**
182     * How many threads are to be created
183     */
184    protected int threadCount;
185
186    /**
187     * A way to override creation of the thread workers
188     */
189    private ThreadWorkerCreator<T> workerCreator;
190
191    public AbstractThreadedTask() {
192        this.terminated = new AtomicBoolean(false);
193        this.successCounter = new AtomicInteger(0);
194        this.failureCounter = new AtomicInteger(0);
195        this.terminationHandlers = new ArrayList<>();
196    }
197
198    /**
199     * Adds a termination handler to this execution of the task
200     * @param handler The termination handler to run on completion
201     */
202    protected final void addTerminationHandler(Functions.GenericCallback handler) {
203        this.terminationHandlers.add(handler);
204    }
205
206    /**
207     * Invoked by the default worker thread after each batch is completed.
208     * This can be overridden by a subclass to do arbitrary cleanup.
209     *
210     * @param context The context for this thread
211     * @throws GeneralException if anything goes wrong
212     */
213    public void afterBatch(SailPointContext context) throws GeneralException {
214        /* No-op by default */
215    }
216
217    /**
218     * Invoked by the default worker thread before each batch is begun. If this
219     * method throws an exception, the batch worker ought to prevent the batch
220     * from being executed.
221     *
222     * @param context The context for this thread
223     * @throws GeneralException if any failures occur
224     */
225    public void beforeBatch(SailPointContext context) throws GeneralException {
226        /* No-op by default */
227    }
228
229    /**
230     * The main method of this task executor, which invokes the appropriate hook methods.
231     */
232    @Override
233    public final void execute(SailPointContext ctx, TaskSchedule ts, TaskResult tr, Attributes<String, Object> args) throws Exception {
234        TaskMonitor monitor = new TaskMonitor(ctx, tr);
235
236        this.terminated.set(false);
237        this.successCounter.set(0);
238        this.failureCounter.set(0);
239        this.terminationHandlers.clear();
240
241        this.workerCreator = ThreadExecutorWorker::new;
242        this.failureMarker = this::markFailure;
243        this.successMarker = this::markSuccess;
244
245        this.log = LogFactory.getLog(this.getClass());
246        this.context = ctx;
247        this.taskResult = tr;
248        this.taskSchedule = ts;
249
250        monitor.updateProgress("Parsing input arguments");
251        monitor.commitMasterResult();
252
253        parseArgs(args);
254
255        monitor.updateProgress("Retrieving target objects");
256        monitor.commitMasterResult();
257
258        Iterator<? extends T> items = getObjectIterator(ctx, args);
259        if (items != null) {
260            monitor.updateProgress("Processing target objects");
261            monitor.commitMasterResult();
262
263            submitAndWait(ctx, taskResult, items);
264
265            if (!terminated.get()) {
266                monitor.updateProgress("Invoking termination handlers");
267                monitor.commitMasterResult();
268
269                runTerminationHandlers();
270            }
271        }
272    }
273
274    /**
275     * Retrieves an iterator over batches of items, with the size suggested by the second
276     * parameter. If left unmodified, returns either a {@link BatchingIterator} when the
277     * batch size is greater than 1, or a {@link TransformingIterator} that constructs a
278     * singleton list for each item when batch size is 1.
279     *
280     * If possible, the returned Iterator should simply wrap the input, rather than
281     * consuming it. This allows for "live" iterators that read from a data source
282     * directly rather than pre-reading. However, beware Hibernate iterators here
283     * because a 'commit' can kill those mid-iterate.
284     *
285     * @param items The input iterator of items
286     * @param batchSize The batch size
287     * @return The iterator over a list of items
288     */
289    protected Iterator<List<T>> createBatchIterator(Iterator<? extends T> items, int batchSize) {
290        Iterator<List<T>> batchingIterator;
291        if (batchSize > 1) {
292            // Batching iterator will combine items into lists of up to batchSize
293            batchingIterator = new BatchingIterator<>(items, batchSize);
294        } else {
295            // This iterator will just transform each item into a list containing only that item
296            batchingIterator = new TransformingIterator<T, List<T>>(items, Collections::singletonList);
297        }
298        return batchingIterator;
299    }
300
301    /**
302     * Gets the batch size for this task. By default, this is the batch size passed
303     * as an input to the task, but this may be overridden by subclasses.
304     *
305     * @return The batch size for each thread
306     */
307    public int getBatchSize() {
308        return batchSize;
309    }
310
311    /**
312     * Gets the running executor for this task
313     * @return The executor
314     */
315    public final ExecutorService getExecutor() {
316        return this.executor;
317    }
318
319    /**
320     * Retrieves an Iterator that will produce the stream of objects to be processed
321     * in parallel. Each object produced by this Iterator will be passed in its turn
322     * to {@link #threadExecute(SailPointContext, Map, Object)} as the third parameter.
323     *
324     * IMPORTANT NOTES RE: HIBERNATE:
325     *
326     * It may be unwise to return a "live" Hibernate iterator of the sort provided by
327     * context.search here. The next read of the iterator will fail with a "Result Set
328     * Closed" exception if anything commits this context while the iterator is still
329     * being consumed. It is likely that the first worker threads will execute before
330     * the iterator is fully read.
331     *
332     * If you return a SailPointObject or any other object dependent on a Hibernate
333     * context, you will likely receive context-related errors in your worker thread
334     * unless you make an effort to re-attach the object to the thread context.
335     *
336     * TODO One option may be to pass in a private context here, but it couldn't be closed until after iteration is complete.
337     *
338     * @param context The top-level task Sailpoint context
339     * @param args The task arguments
340     * @return An iterator containing the objects to be iterated over
341     * @throws GeneralException if any failures occur
342     */
343    protected abstract Iterator<? extends T> getObjectIterator(SailPointContext context, Attributes<String, Object> args) throws GeneralException;
344
345    /**
346     * Marks this item as a failure by incrementing the failure counter. Subclasses
347     * may override this method to add additional behavior.
348     */
349    protected void markFailure(T item) {
350        failureCounter.incrementAndGet();
351    }
352
353    /**
354     * Marks this item as a success by incrementing the success counter. Subclasses
355     * may override this method to add additional behavior.
356     */
357    protected void markSuccess(T item) {
358        successCounter.incrementAndGet();
359    }
360
361    /**
362     * Extracts the thread count from the task arguments. Subclasses should override
363     * this method to extract their own arguments. You must either call super.parseArgs()
364     * in any subclass implementation of this method or set {@link #threadCount} yourself.
365     *
366     * @param args The task arguments
367     * @throws Exception if any failures occur parsing arguments
368     */
369    protected void parseArgs(Attributes<String, Object> args) throws Exception {
370        this.threadCount = Util.atoi(args.getString(CommonConstants.THREADS_ATTR));
371        if (this.threadCount < 1) {
372            this.threadCount = Util.atoi(args.getString("threadCount"));
373        }
374        if (this.threadCount < 1) {
375            this.threadCount = 1;
376        }
377
378        this.batchSize = args.getInt("batchSize", 0);
379    }
380
381    /**
382     * Prepares the thread pool executor. The default implementation simply constructs
383     * a fixed-size executor service, but subclasses may override this behavior with
384     * their own implementations.
385     *
386     * After this method is finished, the {@link #executor} attribute should be set
387     * to an {@link ExecutorService} that can accept new inputs.
388     *
389     * @throws GeneralException if any failures occur
390     */
391    protected void prepareExecutor() throws GeneralException {
392        executor = Executors.newFixedThreadPool(threadCount);
393    }
394
395    /**
396     * Runs the cleanup handler
397     */
398    private void runTerminationHandlers() {
399        if (terminationHandlers.size() > 0) {
400            try {
401                SailPointContext context = SailPointFactory.getCurrentContext();
402                for (Functions.GenericCallback handler : terminationHandlers) {
403                    try {
404                        handler.run(context);
405                    } catch(Error e) {
406                        throw e;
407                    } catch(Throwable e) {
408                        log.error("Caught an error while running termination handlers", e);
409                    }
410                }
411            } catch (GeneralException e) {
412                log.error("Caught an error while running termination handlers", e);
413            }
414        }
415    }
416
417    /**
418     * Sets the "before execution hook", an optional pluggable callback that will
419     * be invoked prior to the execution of each thread. This BiConsumer's accept()
420     * method must be thread-safe as it will be invoked in parallel.
421     *
422     * @param beforeExecutionHook An optional BiConsumer callback hook
423     */
424    public final void setBeforeExecutionHook(BiConsumer<Thread, Object> beforeExecutionHook) {
425        this.beforeExecutionHook = beforeExecutionHook;
426    }
427
428    /**
429     * Sets the failure marking callback
430     * @param failureMarker The callback invoked on item failure
431     */
432    public final void setFailureMarker(Consumer<T> failureMarker) {
433        this.failureMarker = failureMarker;
434    }
435
436    /**
437     * Sets the success marking callback
438     * @param successMarker The callback invoked on item failure
439     */
440    public final void setSuccessMarker(Consumer<T> successMarker) {
441        this.successMarker = successMarker;
442    }
443
444    /**
445     * Sets the worker creator function. This function should return a SailPointWorker
446     * extension that will take the given List of typed items and process them when
447     * its thread is invoked.
448     *
449     * @param workerCreator The worker creator function
450     */
451    public final void setWorkerCreator(ThreadWorkerCreator<T> workerCreator) {
452        this.workerCreator = workerCreator;
453    }
454
455    /**
456     * Submits the iterator of items to the thread pool, calling threadExecute for each
457     * one, then waits for all of the threads to complete or the task to be terminated.
458     *
459     * @param context The SailPoint context
460     * @param taskResult The taskResult to update (for monitoring)
461     * @param items The iterator over items being processed
462     * @throws GeneralException if any failures occur
463     */
464    protected void submitAndWait(SailPointContext context, TaskResult taskResult, Iterator<? extends T> items) throws GeneralException {
465        int batchSize = getBatchSize();
466        final TaskMonitor monitor = new TaskMonitor(context, taskResult);
467
468        // Default listener allowing individual worker state to be propagated up
469        // through the various callbacks, hooks, and listeners on this task.
470        ThreadedTaskListener<T> taskContext = new DefaultThreadedTaskListener(taskResult);
471        try {
472            prepareExecutor();
473            AtomicInteger totalCount = new AtomicInteger();
474            try {
475                Iterator<List<T>> batchingIterator = createBatchIterator(items, batchSize);
476
477                batchingIterator.forEachRemaining(listOfObjects -> {
478                    SailPointWorker worker = workerCreator.createWorker(listOfObjects, this, taskContext);
479                    worker.setMonitor(monitor);
480                    executor.submit(worker.runnable());
481                    totalCount.incrementAndGet();
482                });
483            } finally {
484                Util.flushIterator(items);
485            }
486
487            try {
488                monitor.updateProgress("Submitted " + totalCount.get() + " tasks");
489                monitor.commitMasterResult();
490            } catch(GeneralException e) {
491                /* Ignore this */
492            }
493
494            // No further items can be submitted to the executor at this point
495            executor.shutdown();
496
497            log.info("Waiting for all threads in task " + taskResult.getName() + " to terminate");
498
499            int totalItems = totalCount.get();
500            if (batchSize > 1) {
501                totalItems = totalItems * batchSize;
502            }
503
504            while(!executor.isTerminated()) {
505                executor.awaitTermination(5, TimeUnit.SECONDS);
506                int finished = this.successCounter.get() + this.failureCounter.get();
507                try {
508                    monitor.updateProgress("Completed " + finished + " of " + totalItems + " items");
509                    monitor.commitMasterResult();
510                } catch(GeneralException e) {
511                    /* Ignore this */
512                }
513            }
514
515            log.info("All threads have terminated in task " + taskResult.getName());
516
517            taskResult.setAttribute("successes", successCounter.get());
518            taskResult.setAttribute("failures", failureCounter.get());
519            monitor.commitMasterResult();
520        } catch(InterruptedException e) {
521            terminate();
522            throw new GeneralException(e);
523        }
524    }
525
526    /**
527     * Terminates the task by setting the terminated flag, interrupting the executor, waiting five seconds for it to exit, then invoking any shutdown hooks
528     *
529     * @return Always true
530     */
531    @Override
532    public final boolean terminate() {
533        if (!terminated.get()) {
534            synchronized(terminated) {
535                if (!terminated.get()) {
536                    terminated.set(true);
537                    if (executor != null && !executor.isTerminated()) {
538                        executor.shutdownNow();
539                        if (terminationHandlers.size() > 0) {
540                            try {
541                                executor.awaitTermination(2L, TimeUnit.SECONDS);
542                            } catch (InterruptedException e) {
543                                log.debug("Interrupted while waiting during termination", e);
544                            }
545                        }
546                    }
547                    runTerminationHandlers();
548                }
549            }
550        }
551        return true;
552    }
553
554    /**
555     * This method will be called in parallel for each item produced by {@link #getObjectIterator(SailPointContext, Attributes)}.
556     *
557     * DO NOT use the parent context in this method. You will encounter Weird Database Issues.
558     *
559     * @param threadContext A private IIQ context for the current JVM thread
560     * @param parameters A set of default parameters suitable for a Rule or Script. In the default implementation, the object will be in this map as 'object'.
561     * @param obj The object to process
562     * @return An arbitrary value (ignored by default)
563     * @throws GeneralException if any failures occur
564     */
565    public abstract Object threadExecute(SailPointContext threadContext, Map<String, Object> parameters, T obj) throws GeneralException;
566}