diff --git a/edx/analytics/tasks/common/spark.py b/edx/analytics/tasks/common/spark.py
index 903af51814..c038143457 100644
--- a/edx/analytics/tasks/common/spark.py
+++ b/edx/analytics/tasks/common/spark.py
@@ -14,7 +14,7 @@
     ManifestInputTargetMixin, convert_to_manifest_input_if_necessary, remove_manifest_target_if_exists
 )
 from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
-from edx.analytics.tasks.util.url import get_target_from_url, url_path_join
+from edx.analytics.tasks.util.url import UncheckedExternalURL, get_target_from_url, url_path_join
 
 _file_path_to_package_meta_path = {}
 
@@ -163,31 +163,42 @@ class PathSelectionTaskSpark(EventLogSelectionDownstreamMixin, luigi.WrapperTask
     """
     Path selection task with manifest feature for spark
     """
-    requirements = None
+    targets = None
     manifest_id = luigi.Parameter(
         description='File name for manifest'
     )
+    manifest_dir = luigi.Parameter(
+        description='Directory for manifest files'
+    )
+    pyspark_logger = luigi.Parameter(
+        description='Pyspark logger',
+        default=None
+    )
 
     def requires(self):
-        yield PathSelectionByDateIntervalTask(
+        if not self.targets:
+            if self.pyspark_logger:
+                self.pyspark_logger.warn("PathSelectionTaskSpark=> targets not found, refreshing!")
+            self.targets = self._get_targets()
+        else:
+            if self.pyspark_logger:
+                self.pyspark_logger.warn("PathSelectionTaskSpark=> targets already exist")
+        return self.targets
+
+    def _get_targets(self):
+        input = PathSelectionByDateIntervalTask(
             source=self.source,
             interval=self.interval,
             pattern=self.pattern,
             date_pattern=self.date_pattern
+        ).output()
+        targets = luigi.task.flatten(
+            convert_to_manifest_input_if_necessary(self.manifest_id, input, self.manifest_dir)
         )
-
-    def get_target_paths(self):
-        log.warn("PathSelectionTaskSpark: checking requirements {}".format(self.manifest_id))
-        if not self.requirements:
-            log.warn("PathSelectionTaskSpark: requirements not found, refreshing!!")
-            targets = luigi.task.flatten(
-                convert_to_manifest_input_if_necessary(self.manifest_id, self.input())
-            )
-            self.requirements = targets
-        return self.requirements
+        return [UncheckedExternalURL(target.path) for target in targets]
 
     def output(self):
-        return self.get_target_paths()
+        return [target.output() for target in self.requires()]
 
 
 class EventLogSelectionMixinSpark(EventLogSelectionDownstreamMixin):
@@ -240,19 +251,26 @@ def get_log_schema(self):
         return event_log_schema
 
     def get_input_rdd(self, *args):
-        manifest_target = self.get_manifest_path(*args)
-        self.log.warn("PYSPARK LOGGER : Getting input rdd ---> target : {}".format(manifest_target.path))
-        if manifest_target.exists():
+        manifest_path = self.get_config_from_args('manifest_path', *args, default_value='')
+        targets = PathSelectionTaskSpark(
+            source=self.source,
+            interval=self.interval,
+            pattern=self.pattern,
+            date_pattern=self.date_pattern,
+            manifest_id=self.manifest_id,
+            manifest_dir=manifest_path,
+            pyspark_logger=self.log
+        ).output()
+        if len(targets) and 'manifest' in targets[0].path:
             # Reading manifest as rdd with spark is alot faster as compared to hadoop.
             # Currently, we're getting only 1 manifest file per request, so we will create a single rdd from it.
             # If there are multiple manifest files, each file can be read as rdd and then union it with other manifest rdds
-            self.log.warn("PYSPARK LOGGER: Reading manifest file :: {} ".format(manifest_target.path))
-            source_rdd = self._spark.sparkContext.textFile(manifest_target.path)
+            self.log.warn("PYSPARK LOGGER: Reading manifest file :: {} ".format(targets[0].path))
+            source_rdd = self._spark.sparkContext.textFile(targets[0].path, 1)
         else:
             # maybe we only need to broadcast it ( on cluster ) and not create rdd. lets see
             self.log.warn("PYSPARK LOGGER: Reading normal targets")
-            input_targets = luigi.task.flatten(self.input())
-            source_rdd = self._spark.sparkContext.parallelize([target.path for target in input_targets])
+            source_rdd = self._spark.sparkContext.parallelize([target.path for target in targets])
         return source_rdd
 
     def get_event_log_dataframe(self, spark, *args, **kwargs):
@@ -309,7 +327,7 @@ def manifest_id(self):
             'interval': self.interval,
             'pattern': self.pattern,
             'date_pattern': self.date_pattern,
-            'spark':'for_some_difference_with_hadoop_manifest'
+            'spark': 'for_some_difference_with_hadoop_manifest'
         }
         return str(hash(frozenset(params.items()))).replace('-', 'n')
 
@@ -322,15 +340,6 @@ def get_manifest_path(self, *args):
             )
         )
 
-    def requires(self):
-        yield PathSelectionTaskSpark(
-            source=self.source,
-            interval=self.interval,
-            pattern=self.pattern,
-            date_pattern=self.date_pattern,
-            manifest_id=self.manifest_id
-        )
-
     def spark_job(self):
         """
         Spark code for the job
diff --git a/edx/analytics/tasks/util/manifest.py b/edx/analytics/tasks/util/manifest.py
index f8b4600a52..8cfec843fc 100644
--- a/edx/analytics/tasks/util/manifest.py
+++ b/edx/analytics/tasks/util/manifest.py
@@ -12,14 +12,14 @@
 log = logging.getLogger(__name__)
 
 
-def convert_to_manifest_input_if_necessary(manifest_id, targets):
+def convert_to_manifest_input_if_necessary(manifest_id, targets, manifest_dir=None):
     targets = luigi.task.flatten(targets)
     threshold = configuration.get_config().getint(CONFIG_SECTION, 'threshold', -1)
     if threshold > 0 and len(targets) >= threshold:
         log.debug(
             'Using manifest since %d inputs are greater than or equal to the threshold %d', len(targets), threshold
         )
-        return [create_manifest_target(manifest_id, targets)]
+        return [create_manifest_target(manifest_id, targets, manifest_dir)]
     else:
         log.debug(
             'Directly processing files since %d inputs are less than the threshold %d', len(targets), threshold
@@ -27,14 +27,15 @@ def convert_to_manifest_input_if_necessary(manifest_id, targets):
         )
         return targets
 
 
-def get_manifest_file_path(manifest_id):
+def get_manifest_file_path(manifest_id, manifest_dir=None):
     # Construct the manifest file URL from the manifest_id and the configuration
-    base_url = configuration.get_config().get(CONFIG_SECTION, 'path')
-    manifest_file_path = url_path_join(base_url, manifest_id + '.manifest')
+    if manifest_dir is None:
+        manifest_dir = configuration.get_config().get(CONFIG_SECTION, 'path')
+    manifest_file_path = url_path_join(manifest_dir, manifest_id + '.manifest')
 
     return manifest_file_path
 
 
-def create_manifest_target(manifest_id, targets):
+def create_manifest_target(manifest_id, targets, manifest_dir=None):
     # If we are running locally, we need our manifest file to be a local file target, however, if we are running on
     # a real Hadoop cluster, it has to be an HDFS file so that the input format can read it.  Luigi makes it a little
     # difficult for us to construct a target that can be one or the other of those types of targets at runtime since
@@ -42,7 +43,7 @@ def create_manifest_target(manifest_id, targets):
     # base class at runtime based on the URL of the manifest file.
 
     # Construct the manifest file URL from the manifest_id and the configuration
-    manifest_file_path = get_manifest_file_path(manifest_id)
+    manifest_file_path = get_manifest_file_path(manifest_id, manifest_dir)
 
     # Figure out the type of target that should be used to write/read the file.
     manifest_file_target_class, init_args, init_kwargs = get_target_class_from_url(manifest_file_path)
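
Note on the manifest flow this patch parameterizes, as a simplified sketch rather than the repo's implementation: once the number of input targets reaches the configured threshold, all of the paths are written to a single <manifest_id>.manifest file — now placed under the caller-supplied manifest_dir when given, otherwise under the configured manifest path — and only that manifest is handed to Spark; below the threshold the raw target paths are used directly. The helper name choose_input_paths and the plain-file I/O below are illustrative stand-ins for the Luigi targets and configuration objects used in the real code.

import os


def choose_input_paths(manifest_id, paths, threshold, manifest_dir=None, configured_dir='/tmp/manifests'):
    """Return either the raw input paths or a single freshly written manifest file."""
    if threshold > 0 and len(paths) >= threshold:
        # Many inputs: write one path per line into a manifest file and return only that file.
        target_dir = manifest_dir if manifest_dir else configured_dir
        os.makedirs(target_dir, exist_ok=True)
        manifest_path = os.path.join(target_dir, manifest_id + '.manifest')
        with open(manifest_path, 'w') as manifest_file:
            manifest_file.write('\n'.join(paths) + '\n')
        return [manifest_path]
    # Few inputs: hand the raw paths straight to Spark.
    return paths

get_input_rdd then distinguishes the two cases by checking whether the single returned path contains 'manifest': a manifest is read with sparkContext.textFile(path, 1), one input path per line, while a plain list of paths is distributed with sparkContext.parallelize.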