001package com.identityworksllc.iiq.common.task;
002
import com.identityworksllc.iiq.common.CommonConstants;
import com.identityworksllc.iiq.common.Functions;
import com.identityworksllc.iiq.common.iterators.BatchingIterator;
import com.identityworksllc.iiq.common.iterators.TransformingIterator;
import com.identityworksllc.iiq.common.threads.SailPointWorker;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import sailpoint.api.SailPointContext;
import sailpoint.api.SailPointFactory;
import sailpoint.object.Attributes;
import sailpoint.object.TaskResult;
import sailpoint.object.TaskSchedule;
import sailpoint.task.AbstractTaskExecutor;
import sailpoint.task.TaskMonitor;
import sailpoint.tools.GeneralException;
import sailpoint.tools.Util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
032
033/**
034 * An abstract superclass for nearly all custom multi-threaded SailPoint tasks. This task
035 * will retrieve a list of objects and then pass each of them to a processor method
036 * in the subclass in parallel.
037 *
038 * The overall flow is:
039 *
040 *  1. Invoke {@link #parseArgs(Attributes)} to extract task arguments. Subclasses should
041 *     override this to retrieve their own parameters.
042 *  2. Invoke {@link #getObjectIterator(SailPointContext, Attributes)} to retrieve a list of
043 *     items to iterate over. This iterator can be "streaming" or not.
044 *  3. Invoke {@link #submitAndWait(SailPointContext, TaskResult, Iterator)} to begin processing
045 *     of the list of items. This can be replaced by a subclass, but the default flow is described
046 *     below. It is unlikely that you will need to change it.
047 *  4. Clean up the thread pool and update the TaskResult with outcomes.
048 *
049 * Submit-and-wait proceeds as follows:
050 *
051 *  1. Retrieve the batch size from {@link #getBatchSize()}
052 *  2. Create a thread pool with the specified number of threads in it.
053 *  3. For each item, invoke the subclass's {@link #threadExecute(SailPointContext, Map, T)} method,
054 *     passing the current item in the third parameter. If a batch size is set, more than one
055 *     item will be passed in a single worker thread, eliminating the need to build and destroy
056 *     lots of private contexts. This will likely be more efficient for large operations.
057 *
058 * Via the {@link SailPointWorker} class, the {@link #threadExecute(SailPointContext, Map, T)} method
059 * will also receive an appropriately thread-specific SailPointContext object that can be used without
060 * worrying about transaction length or overlap.
061 *
062 * Subclasses can override most parts of this process by extending the various protected methods.
063 *
064 * Subclasses cannot direct receive a termination notification, but can register any number of
065 * termination handlers by invoking {@link #addTerminationHandler(Functions.GenericCallback)}.
066 * This class makes a best effort to invoke all termination handlers.
067 *
068 * @param <T> The type of input object that will be passed to threadExecute.
069 */
070public abstract class AbstractThreadedTask<T> extends AbstractTaskExecutor implements PrivateContextObjectConsumer<T> {
071
072    /**
073     * The default threaded task listener
074     */
075    private class DefaultThreadedTaskListener extends ThreadedTaskListener<T> {
076        /**
077         * The task result to update with output
078         */
079        private final TaskResult taskResult;
080
081        public DefaultThreadedTaskListener(TaskResult taskResult) {
082            this.taskResult = taskResult;
083        }
084
085        @Override
086        public void afterBatch(SailPointContext threadContext) throws GeneralException {
087            AbstractThreadedTask.this.beforeBatch(threadContext);
088        }
089
090        @Override
091        public void beforeBatch(SailPointContext taskContext) throws GeneralException {
092            AbstractThreadedTask.this.beforeBatch(taskContext);
093        }
094
095        @Override
096        public void beforeExecution(Thread theThread, T input) {
097            if (beforeExecutionHook != null) {
098                beforeExecutionHook.accept(theThread, input);
099            }
100        }
101
102        @Override
103        public void handleException(Exception e) {
104            taskResult.addException(e);
105        }
106
107        @Override
108        public void handleFailure(T input) {
109            failureMarker.accept(input);
110        }
111
112        @Override
113        public void handleSuccess(T input) {
114            successMarker.accept(input);
115        }
116
117        @Override
118        public boolean isTerminated() {
119            return AbstractThreadedTask.this.terminated.get();
120        }
121    }
122
123    /**
124     * The batch size, which may be zero for no batching
125     */
126    private int batchSize;
127    /**
128     * If present, this BiConsumer can be invoked before execution of each object.
129     * The subclass is responsible for making this call. This is mainly useful as
130     * a testing hook.
131     */
132    protected BiConsumer<Thread, Object> beforeExecutionHook;
133    /**
134     * The parent SailPoint context
135     */
136    protected SailPointContext context;
137    /**
138     * The thread pool
139     */
140    protected ExecutorService executor;
141    /**
142     * The counter of how many threads have indicated failures
143     */
144    protected AtomicInteger failureCounter;
145    /**
146     * The callback on failed execution for each item
147     */
148    private Consumer<T> failureMarker;
149    /**
150     * The log object
151     */
152    protected Log log;
153    /**
154     * The counter of how many threads have indicated success
155     */
156    protected AtomicInteger successCounter;
157    /**
158     * The callback on successful execution for each item
159     */
160    private Consumer<T> successMarker;
161    /**
162     * The TaskResult to keep updated with changes
163     */
164    protected TaskResult taskResult;
165
166    /**
167     * The TaskSchedule, which can be used in querying
168     */
169    protected TaskSchedule taskSchedule;
170    /**
171     * The boolean flag indicating that this task has been terminated
172     */
173    protected final AtomicBoolean terminated;
174    /**
175     * A set of callbacks to run on task termination
176     */
177    private final List<Functions.GenericCallback> terminationHandlers;
178    /**
179     * How many threads are to be created
180     */
181    protected int threadCount;
182
183    /**
184     * A way to override creation of the thread workers
185     */
186    private ThreadWorkerCreator<T> workerCreator;
187
188    public AbstractThreadedTask() {
189        this.terminated = new AtomicBoolean(false);
190        this.successCounter = new AtomicInteger(0);
191        this.failureCounter = new AtomicInteger(0);
192        this.terminationHandlers = new ArrayList<>();
193    }
194
195    /**
196     * Adds a termination handler to this execution of the task
197     * @param handler The termination handler to run on completion
198     */
199    protected final void addTerminationHandler(Functions.GenericCallback handler) {
200        this.terminationHandlers.add(handler);
201    }
202
203    /**
204     * Invoked by the default worker thread after each batch is completed.
205     * This can be overridden by a subclass to do arbitrary cleanup.
206     *
207     * @param context The context for this thread
208     * @throws GeneralException if anything goes wrong
209     */
210    public void afterBatch(SailPointContext context) throws GeneralException {
211        /* No-op by default */
212    }
213
214    /**
215     * Invoked by the default worker thread before each batch is begun. If this
216     * method throws an exception, the batch worker ought to prevent the batch
217     * from being executed.
218     *
219     * @param context The context for this thread
220     * @throws GeneralException if any failures occur
221     */
222    public void beforeBatch(SailPointContext context) throws GeneralException {
223        /* No-op by default */
224    }
225
226    /**
227     * The main method of this task executor, which invokes the appropriate hook methods.
228     */
229    @Override
230    public final void execute(SailPointContext ctx, TaskSchedule ts, TaskResult tr, Attributes<String, Object> args) throws Exception {
231        TaskMonitor monitor = new TaskMonitor(ctx, tr);
232
233        this.terminated.set(false);
234        this.successCounter.set(0);
235        this.failureCounter.set(0);
236        this.terminationHandlers.clear();
237
238        this.workerCreator = ThreadExecutorWorker::new;
239        this.failureMarker = this::markFailure;
240        this.successMarker = this::markSuccess;
241
242        this.log = LogFactory.getLog(this.getClass());
243        this.context = ctx;
244        this.taskResult = tr;
245        this.taskSchedule = ts;
246
247        monitor.updateProgress("Parsing input arguments");
248        monitor.commitMasterResult();
249
250        parseArgs(args);
251
252        monitor.updateProgress("Retrieving target objects");
253        monitor.commitMasterResult();
254
255        Iterator<? extends T> items = getObjectIterator(ctx, args);
256        if (items != null) {
257            monitor.updateProgress("Processing target objects");
258            monitor.commitMasterResult();
259
260            submitAndWait(ctx, taskResult, items);
261
262            if (!terminated.get()) {
263                monitor.updateProgress("Invoking termination handlers");
264                monitor.commitMasterResult();
265
266                runTerminationHandlers();
267            }
268        }
269    }
270
271    /**
272     * Retrieves an iterator over batches of items, with the size suggested by the second
273     * parameter. If left unmodified, returns either a {@link BatchingIterator} when the
274     * batch size is greater than 1, or a {@link TransformingIterator} that constructs a
275     * singleton list for each item when batch size is 1.
276     *
277     * If possible, the returned Iterator should simply wrap the input, rather than
278     * consuming it. This allows for "live" iterators that read from a data source
279     * directly rather than pre-reading. However, beware Hibernate iterators here
280     * because a 'commit' can kill those mid-iterate.
281     *
282     * @param items The input iterator of items
283     * @param batchSize The batch size
284     * @return The iterator over a list of items
285     */
286    protected Iterator<List<T>> createBatchIterator(Iterator<? extends T> items, int batchSize) {
287        Iterator<List<T>> batchingIterator;
288        if (batchSize > 1) {
289            // Batching iterator will combine items into lists of up to batchSize
290            batchingIterator = new BatchingIterator<>(items, batchSize);
291        } else {
292            // This iterator will just transform each item into a list containing only that item
293            batchingIterator = new TransformingIterator<T, List<T>>(items, Collections::singletonList);
294        }
295        return batchingIterator;
296    }
297
298    /**
299     * Gets the batch size for this task. By default, this is the batch size passed
300     * as an input to the task, but this may be overridden by subclasses.
301     *
302     * @return The batch size for each thread
303     */
304    public int getBatchSize() {
305        return batchSize;
306    }
307
308    /**
309     * Gets the running executor for this task
310     * @return The executor
311     */
312    public final ExecutorService getExecutor() {
313        return this.executor;
314    }
315
316    /**
317     * Retrieves an Iterator that will produce the stream of objects to be processed
318     * in parallel. Each object produced by this Iterator will be passed in its turn
319     * to {@link #threadExecute(SailPointContext, Map, Object)} as the third parameter.
320     *
321     * IMPORTANT NOTES RE: HIBERNATE:
322     *
323     * It may be unwise to return a "live" Hibernate iterator of the sort provided by
324     * context.search here. The next read of the iterator will fail with a "Result Set
325     * Closed" exception if anything commits this context while the iterator is still
326     * being consumed. It is likely that the first worker threads will execute before
327     * the iterator is fully read.
328     *
329     * If you return a SailPointObject or any other object dependent on a Hibernate
330     * context, you will likely receive context-related errors in your worker thread
331     * unless you make an effort to re-attach the object to the thread context.
332     *
333     * TODO One option may be to pass in a private context here, but it couldn't be closed until after iteration is complete.
334     *
335     * @param context The top-level task Sailpoint context
336     * @param args The task arguments
337     * @return An iterator containing the objects to be iterated over
338     * @throws GeneralException if any failures occur
339     */
340    protected abstract Iterator<? extends T> getObjectIterator(SailPointContext context, Attributes<String, Object> args) throws GeneralException;
341
342    /**
343     * Marks this item as a failure by incrementing the failure counter. Subclasses
344     * may override this method to add additional behavior.
345     */
346    protected void markFailure(T item) {
347        failureCounter.incrementAndGet();
348    }
349
350    /**
351     * Marks this item as a success by incrementing the success counter. Subclasses
352     * may override this method to add additional behavior.
353     */
354    protected void markSuccess(T item) {
355        successCounter.incrementAndGet();
356    }
357
358    /**
359     * Extracts the thread count from the task arguments. Subclasses should override
360     * this method to extract their own arguments. You must either call super.parseArgs()
361     * in any subclass implementation of this method or set {@link #threadCount} yourself.
362     *
363     * @param args The task arguments
364     * @throws Exception if any failures occur parsing arguments
365     */
366    protected void parseArgs(Attributes<String, Object> args) throws Exception {
367        this.threadCount = Util.atoi(args.getString(CommonConstants.THREADS_ATTR));
368        if (this.threadCount < 1) {
369            this.threadCount = Util.atoi(args.getString("threadCount"));
370        }
371        if (this.threadCount < 1) {
372            this.threadCount = 1;
373        }
374
375        this.batchSize = args.getInt("batchSize", 0);
376    }
377
378    /**
379     * Prepares the thread pool executor. The default implementation simply constructs
380     * a fixed-size executor service, but subclasses may override this behavior with
381     * their own implementations.
382     *
383     * After this method is finished, the {@link #executor} attribute should be set
384     * to an {@link ExecutorService} that can accept new inputs.
385     *
386     * @throws GeneralException if any failures occur
387     */
388    protected void prepareExecutor() throws GeneralException {
389        executor = Executors.newFixedThreadPool(threadCount);
390    }
391
392    /**
393     * Runs the cleanup handler
394     */
395    private void runTerminationHandlers() {
396        if (terminationHandlers.size() > 0) {
397            try {
398                SailPointContext context = SailPointFactory.getCurrentContext();
399                for (Functions.GenericCallback handler : terminationHandlers) {
400                    try {
401                        handler.run(context);
402                    } catch(Error e) {
403                        throw e;
404                    } catch(Throwable e) {
405                        log.error("Caught an error while running termination handlers", e);
406                    }
407                }
408            } catch (GeneralException e) {
409                log.error("Caught an error while running termination handlers", e);
410            }
411        }
412    }
413
414    /**
415     * Sets the "before execution hook", an optional pluggable callback that will
416     * be invoked prior to the execution of each thread. This BiConsumer's accept()
417     * method must be thread-safe as it will be invoked in parallel.
418     *
419     * @param beforeExecutionHook An optional BiConsumer callback hook
420     */
421    public final void setBeforeExecutionHook(BiConsumer<Thread, Object> beforeExecutionHook) {
422        this.beforeExecutionHook = beforeExecutionHook;
423    }
424
425    /**
426     * Sets the failure marking callback
427     * @param failureMarker The callback invoked on item failure
428     */
429    public final void setFailureMarker(Consumer<T> failureMarker) {
430        this.failureMarker = failureMarker;
431    }
432
433    /**
434     * Sets the success marking callback
435     * @param successMarker The callback invoked on item failure
436     */
437    public final void setSuccessMarker(Consumer<T> successMarker) {
438        this.successMarker = successMarker;
439    }
440
441    /**
442     * Sets the worker creator function. This function should return a SailPointWorker
443     * extension that will take the given List of typed items and process them when
444     * its thread is invoked.
445     *
446     * @param workerCreator The worker creator function
447     */
448    public final void setWorkerCreator(ThreadWorkerCreator<T> workerCreator) {
449        this.workerCreator = workerCreator;
450    }
451
452    /**
453     * Submits the iterator of items to the thread pool, calling threadExecute for each
454     * one, then waits for all of the threads to complete or the task to be terminated.
455     *
456     * @param context The SailPoint context
457     * @param taskResult The taskResult to update (for monitoring)
458     * @param items The iterator over items being processed
459     * @throws GeneralException if any failures occur
460     */
461    protected void submitAndWait(SailPointContext context, TaskResult taskResult, Iterator<? extends T> items) throws GeneralException {
462        int batchSize = getBatchSize();
463        final TaskMonitor monitor = new TaskMonitor(context, taskResult);
464
465        // Default listener allowing individual worker state to be propagated up
466        // through the various callbacks, hooks, and listeners on this task.
467        ThreadedTaskListener<T> taskContext = new DefaultThreadedTaskListener(taskResult);
468        try {
469            prepareExecutor();
470            AtomicInteger totalCount = new AtomicInteger();
471            try {
472                Iterator<List<T>> batchingIterator = createBatchIterator(items, batchSize);
473
474                batchingIterator.forEachRemaining(listOfObjects -> {
475                    SailPointWorker worker = workerCreator.createWorker(listOfObjects, this, taskContext);
476                    worker.setMonitor(monitor);
477                    executor.submit(worker.runnable());
478                    totalCount.incrementAndGet();
479                });
480            } finally {
481                Util.flushIterator(items);
482            }
483
484            try {
485                monitor.updateProgress("Submitted " + totalCount.get() + " tasks");
486                monitor.commitMasterResult();
487            } catch(GeneralException e) {
488                /* Ignore this */
489            }
490
491            // No further items can be submitted to the executor at this point
492            executor.shutdown();
493
494            log.info("Waiting for all threads in task " + taskResult.getName() + " to terminate");
495
496            int totalItems = totalCount.get();
497            if (batchSize > 1) {
498                totalItems = totalItems * batchSize;
499            }
500
501            while(!executor.isTerminated()) {
502                executor.awaitTermination(5, TimeUnit.SECONDS);
503                int finished = this.successCounter.get() + this.failureCounter.get();
504                try {
505                    monitor.updateProgress("Completed " + finished + " of " + totalItems + " items");
506                    monitor.commitMasterResult();
507                } catch(GeneralException e) {
508                    /* Ignore this */
509                }
510            }
511
512            log.info("All threads have terminated in task " + taskResult.getName());
513
514            taskResult.setAttribute("successes", successCounter.get());
515            taskResult.setAttribute("failures", failureCounter.get());
516            monitor.commitMasterResult();
517        } catch(InterruptedException e) {
518            terminate();
519            throw new GeneralException(e);
520        }
521    }
522
523    /**
524     * Terminates the task by setting the terminated flag, interrupting the executor, waiting five seconds for it to exit, then invoking any shutdown hooks
525     *
526     * @return Always true
527     */
528    @Override
529    public final boolean terminate() {
530        if (!terminated.get()) {
531            synchronized(terminated) {
532                if (!terminated.get()) {
533                    terminated.set(true);
534                    if (executor != null && !executor.isTerminated()) {
535                        executor.shutdownNow();
536                        if (terminationHandlers.size() > 0) {
537                            try {
538                                executor.awaitTermination(2L, TimeUnit.SECONDS);
539                            } catch (InterruptedException e) {
540                                log.debug("Interrupted while waiting during termination", e);
541                            }
542                        }
543                    }
544                    runTerminationHandlers();
545                }
546            }
547        }
548        return true;
549    }
550
551    /**
552     * This method will be called in parallel for each item produced by {@link #getObjectIterator(SailPointContext, Attributes)}.
553     *
554     * DO NOT use the parent context in this method. You will encounter Weird Database Issues.
555     *
556     * @param threadContext A private IIQ context for the current JVM thread
557     * @param parameters A set of default parameters suitable for a Rule or Script. In the default implementation, the object will be in this map as 'object'.
558     * @param obj The object to process
559     * @return An arbitrary value (ignored by default)
560     * @throws GeneralException if any failures occur
561     */
562    public abstract Object threadExecute(SailPointContext threadContext, Map<String, Object> parameters, T obj) throws GeneralException;
563}