Module hela.datasets.spark_parquet_dataset

Expand source code
from hela import BaseDataset
from hela._column_classes import _ColumnType
from datetime import date
from typing import Optional, Union, Sequence, Set
from pathlib import Path
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import first


def spark_session() -> SparkSession:
    """Return a local Spark session, creating one if none is active.

    The builder is configured with the application name ``'testing'`` and
    the ``local[2]`` master before ``getOrCreate`` is called.
    """
    builder = SparkSession.builder
    builder = builder.appName('testing')
    builder = builder.master('local[2]')
    return builder.getOrCreate()


class SparkParquetDataset(BaseDataset):
    """Parquet dataset that is read and written through Spark.

    Registers itself with ``BaseDataset`` using ``data_type='parquet'`` and
    adds Spark-based ``write``/``load`` helpers plus the ``get_dates`` and
    ``get_samples`` inspection hooks.
    """

    def __init__(
        self,
        name: str,
        folder: Optional[Union[str, Path]] = None,
        description: Optional[str] = None,
        rich_description_path: Optional[str] = None,
        partition_cols: Optional[Sequence[str]] = None,
        columns: Optional[Sequence[_ColumnType]] = None
    ) -> None:
        """Forward every argument to ``BaseDataset`` with ``data_type`` fixed to ``'parquet'``."""
        super().__init__(
            name=name,
            data_type='parquet',
            folder=folder,
            description=description,
            rich_description_path=rich_description_path,
            columns=columns,
            partition_cols=partition_cols
        )

    def spark(self) -> SparkSession:
        """Return the Spark session used for reading this dataset."""
        return spark_session()

    def write(self, df: DataFrame) -> int:
        """Write ``df`` as parquet to ``self.path`` and return its row count.

        The write uses ``overwrite`` mode and partitions the output by
        ``self.partition_cols``.

        Args:
            df: The Spark DataFrame to persist.

        Returns:
            ``df.count()``, evaluated on the input DataFrame after the write.
        """
        df.write.mode('overwrite').parquet(str(self.path), partitionBy=self.partition_cols)
        # The count is taken from the input DataFrame, not read back from disk.
        return df.count()

    def load(self) -> DataFrame:
        """Read the parquet data at ``self.path`` into a Spark DataFrame."""
        return self.spark().read.parquet(str(self.path))

    def get_dates(self) -> Set[date]:
        """Return the dates found in the partition directories under ``self.path``.

        Every entry directly under ``self.path`` whose name contains ``=`` is
        treated as a ``<column>=<value>`` partition, and the text after the
        last ``=`` is parsed as an ISO date. Entries whose value is not a
        valid ISO date (for example a non-date partition column, or Spark's
        ``__HIVE_DEFAULT_PARTITION__`` placeholder for null values) are
        skipped instead of aborting the whole listing.

        Returns:
            The set of parsed partition dates; empty when none are found.
        """
        dates: Set[date] = set()
        for entry in self.path.iterdir():
            _, separator, value = entry.stem.rpartition('=')
            if not separator:
                # Not a ``key=value`` entry (e.g. ``_SUCCESS``).
                continue
            try:
                dates.add(date.fromisoformat(value))
            except ValueError:
                # Partition value is not a date; ignore it.
                continue
        return dates

    def get_samples(self) -> dict:
        """Return one sample value per column, keyed by column name.

        Only the first 1000 loaded rows are inspected; for each column the
        first non-null value among them is used.
        """
        df = self.load().limit(1000)
        sample_exprs = [first(x, ignorenulls=True).alias(x) for x in df.columns]
        return df.select(sample_exprs).first().asDict()

Functions

def spark_session() ‑> pyspark.sql.session.SparkSession
Expand source code
def spark_session() -> SparkSession:
    """Return a local Spark session, creating one if none is active.

    The builder is configured with the application name ``'testing'`` and
    the ``local[2]`` master before ``getOrCreate`` is called.
    """
    builder = SparkSession.builder
    builder = builder.appName('testing')
    builder = builder.master('local[2]')
    return builder.getOrCreate()

Classes

class SparkParquetDataset (name: str, folder: Union[str, pathlib.Path, ForwardRef(None)] = None, description: Optional[str] = None, rich_description_path: Optional[str] = None, partition_cols: Optional[Sequence[str]] = None, columns: Optional[Sequence[hela._column_classes._ColumnType]] = None)

Abstract Dataset class to be used when building your own datasets.

If you choose to build data interactivity through the data catalog, your own dataset classes are where you would build the authentication and connection logic.

For full usage of the available catalog features implement the functions BaseDataset.get_samples and BaseDataset.get_dates.

Attributes

name
The name of the dataset
data_type
The data type of the dataset, e.g. "parquet" or "bigquery"
description
A description of the dataset as a string
partition_cols
A list of column names to be used for partitioning as strings
rich_description_path
A path to a markdown file with possibilities for longer, more detailed descriptions. Primarily used for generated catalog web page.
columns
A list of class ColumnType objects defining the columns of the dataset
path
The path to the dataset (combination of folder and name)
Expand source code
class SparkParquetDataset(BaseDataset):
    """Parquet dataset that is read and written through Spark.

    Registers itself with ``BaseDataset`` using ``data_type='parquet'`` and
    adds Spark-based ``write``/``load`` helpers plus the ``get_dates`` and
    ``get_samples`` inspection hooks.
    """

    def __init__(
        self,
        name: str,
        folder: Optional[Union[str, Path]] = None,
        description: Optional[str] = None,
        rich_description_path: Optional[str] = None,
        partition_cols: Optional[Sequence[str]] = None,
        columns: Optional[Sequence[_ColumnType]] = None
    ) -> None:
        """Forward every argument to ``BaseDataset`` with ``data_type`` fixed to ``'parquet'``."""
        super().__init__(
            name=name,
            data_type='parquet',
            folder=folder,
            description=description,
            rich_description_path=rich_description_path,
            columns=columns,
            partition_cols=partition_cols
        )

    def spark(self) -> SparkSession:
        """Return the Spark session used for reading this dataset."""
        return spark_session()

    def write(self, df: DataFrame) -> int:
        """Write ``df`` as parquet to ``self.path`` and return its row count.

        The write uses ``overwrite`` mode and partitions the output by
        ``self.partition_cols``.

        Args:
            df: The Spark DataFrame to persist.

        Returns:
            ``df.count()``, evaluated on the input DataFrame after the write.
        """
        df.write.mode('overwrite').parquet(str(self.path), partitionBy=self.partition_cols)
        # The count is taken from the input DataFrame, not read back from disk.
        return df.count()

    def load(self) -> DataFrame:
        """Read the parquet data at ``self.path`` into a Spark DataFrame."""
        return self.spark().read.parquet(str(self.path))

    def get_dates(self) -> Set[date]:
        """Return the dates found in the partition directories under ``self.path``.

        Every entry directly under ``self.path`` whose name contains ``=`` is
        treated as a ``<column>=<value>`` partition, and the text after the
        last ``=`` is parsed as an ISO date. Entries whose value is not a
        valid ISO date (for example a non-date partition column, or Spark's
        ``__HIVE_DEFAULT_PARTITION__`` placeholder for null values) are
        skipped instead of aborting the whole listing.

        Returns:
            The set of parsed partition dates; empty when none are found.
        """
        dates: Set[date] = set()
        for entry in self.path.iterdir():
            _, separator, value = entry.stem.rpartition('=')
            if not separator:
                # Not a ``key=value`` entry (e.g. ``_SUCCESS``).
                continue
            try:
                dates.add(date.fromisoformat(value))
            except ValueError:
                # Partition value is not a date; ignore it.
                continue
        return dates

    def get_samples(self) -> dict:
        """Return one sample value per column, keyed by column name.

        Only the first 1000 loaded rows are inspected; for each column the
        first non-null value among them is used.
        """
        df = self.load().limit(1000)
        sample_exprs = [first(x, ignorenulls=True).alias(x) for x in df.columns]
        return df.select(sample_exprs).first().asDict()

Ancestors

  • hela._base_dataset.BaseDataset
  • abc.ABC

Methods

def get_dates(self) ‑> Set[datetime.date]

Implement this function for date inspection functionality such as BaseDataset.show_dates.

Should return a set of dates when called or None if dates for some reason could not be fetched.

Expand source code
def get_dates(self) -> Set[date]:
    """Return the dates found in the partition directories under ``self.path``.

    Every entry directly under ``self.path`` whose name contains ``=`` is
    treated as a ``<column>=<value>`` partition, and the text after the last
    ``=`` is parsed as an ISO date. Entries whose value is not a valid ISO
    date (for example a non-date partition column, or Spark's
    ``__HIVE_DEFAULT_PARTITION__`` placeholder for null values) are skipped
    instead of aborting the whole listing.

    Returns:
        The set of parsed partition dates; empty when none are found.
    """
    dates: Set[date] = set()
    for entry in self.path.iterdir():
        _, separator, value = entry.stem.rpartition('=')
        if not separator:
            # Not a ``key=value`` entry (e.g. ``_SUCCESS``).
            continue
        try:
            dates.add(date.fromisoformat(value))
        except ValueError:
            # Partition value is not a date; ignore it.
            continue
    return dates
def get_samples(self) ‑> dict

Implement this function for sample inspection functionality used in e.g. BaseDataset.show_columns.

Should return a dictionary of string keys for column names with samples:

>>> {'my_column': 123}

Nested columns should return names with dot-notation:

>>> {'parent_column.my_column': 123}

Or None if samples could not be fetched:

>>> None
Expand source code
def get_samples(self) -> dict:
    """Return one sample value per column, keyed by column name.

    Only the first 1000 loaded rows are inspected; for each column the first
    non-null value among them is used.
    """
    limited = self.load().limit(1000)
    sample_exprs = [
        first(column, ignorenulls=True).alias(column)
        for column in limited.columns
    ]
    sample_row = limited.select(sample_exprs).first()
    return sample_row.asDict()
def load(self)
Expand source code
def load(self):
    """Read the parquet data at ``self.path`` using the dataset's Spark session."""
    reader = self.spark().read
    location = str(self.path)
    return reader.parquet(location)
def spark(self) ‑> pyspark.sql.session.SparkSession
Expand source code
def spark(self) -> SparkSession:
    """Return the Spark session obtained from ``spark_session()``."""
    session = spark_session()
    return session
def write(self, df: pyspark.sql.dataframe.DataFrame) ‑> int
Expand source code
def write(self, df: DataFrame) -> int:
    """Write ``df`` as parquet to ``self.path`` and return its row count.

    The write uses ``overwrite`` mode and partitions the output by
    ``self.partition_cols``; the returned value is ``df.count()`` evaluated
    on the input DataFrame after the write.
    """
    writer = df.write.mode('overwrite')
    writer.parquet(str(self.path), partitionBy=self.partition_cols)
    row_count = df.count()
    return row_count