Module hela.datasets.spark_parquet_dataset
Expand source code
from hela import BaseDataset
from hela._column_classes import _ColumnType
from datetime import date
from typing import Optional, Union, Sequence, Set
from pathlib import Path
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import first
def spark_session(*, app_name: str = 'testing', master: str = 'local[2]') -> SparkSession:
    """Return a SparkSession, creating one if none exists yet.

    Args:
        app_name: Application name used when a new session has to be built.
            Defaults to ``'testing'`` (the previously hard-coded value).
        master: Spark master URL used when a new session has to be built.
            Defaults to ``'local[2]'`` (the previously hard-coded value).

    Returns:
        The session produced by ``SparkSession.builder.getOrCreate()``.

    Note:
        ``getOrCreate`` returns an already-running session if there is one,
        in which case ``app_name``/``master`` generally do not take effect.
    """
    return (
        SparkSession.builder
        .appName(app_name)
        .master(master)
        .getOrCreate()
    )
class SparkParquetDataset(BaseDataset):
    """Dataset stored as a Parquet folder, read and written through Spark.

    The data lives at ``self.path`` (the combination of ``folder`` and
    ``name``, resolved by ``BaseDataset``) and is optionally partitioned on
    ``partition_cols`` when written.
    """

    def __init__(
        self,
        name: str,
        folder: Optional[Union[str, Path]] = None,
        description: Optional[str] = None,
        rich_description_path: Optional[str] = None,
        partition_cols: Optional[Sequence[str]] = None,
        columns: Optional[Sequence[_ColumnType]] = None
    ) -> None:
        super().__init__(
            name=name,
            data_type='parquet',
            folder=folder,
            description=description,
            rich_description_path=rich_description_path,
            columns=columns,
            partition_cols=partition_cols
        )

    def spark(self) -> SparkSession:
        """Return the SparkSession used to read and write this dataset."""
        return spark_session()

    def write(self, df: DataFrame) -> int:
        """Overwrite the dataset with ``df`` and return the number of rows in ``df``.

        Partition columns from ``self.partition_cols`` are passed to Spark's
        ``partitionBy``.

        NOTE(review): ``df.count()`` is a second Spark action, so unless ``df``
        is cached its lineage is evaluated again after the write; for
        non-deterministic inputs the count may not match what was written.
        """
        df.write.mode('overwrite').parquet(str(self.path), partitionBy=self.partition_cols)
        return df.count()

    def load(self) -> DataFrame:
        """Read the whole Parquet folder at ``self.path`` into a DataFrame."""
        return self.spark().read.parquet(str(self.path))

    def get_dates(self) -> Optional[Set[date]]:
        """Return the dates encoded in the dataset's partition directory names.

        Only the first directory level below ``self.path`` is inspected: each
        directory named ``<column>=<value>`` whose value is an ISO-format date
        contributes that date. Entries that are not directories, have no ``=``,
        or whose value is not an ISO date (e.g. Spark's null-partition
        placeholder) are skipped.

        Returns:
            The set of dates found (possibly empty), or None if the dataset
            folder cannot be listed (e.g. it does not exist), as required by
            the ``BaseDataset.get_dates`` contract.
        """
        root = Path(self.path)
        try:
            entries = list(root.iterdir())
        except OSError:
            # Missing or unreadable folder: dates cannot be fetched.
            return None

        dates = set()
        for entry in entries:
            if not entry.is_dir():
                continue
            _, separator, value = entry.name.partition('=')
            if not separator:
                continue
            try:
                dates.add(date.fromisoformat(value))
            except ValueError:
                # Partition value is not a date (non-date column or null placeholder).
                continue
        return dates

    def get_samples(self) -> dict:
        """Return one sample value per column, taken from the first 1000 rows.

        Each top-level column maps to its first non-null value within those
        rows (or None if there is none). For struct (nested) columns, every
        nested field is additionally reported under a dot-notation key such as
        ``'parent.child'``, as the ``BaseDataset.get_samples`` contract asks;
        the top-level struct entry is kept as well for backward compatibility.
        """
        df = self.load().limit(1000)
        row = df.select([first(x, ignorenulls=True).alias(x) for x in df.columns]).first()
        samples = row.asDict()
        for column_name, value in list(samples.items()):
            for dotted_name, nested_value in self._flatten_struct(column_name, value).items():
                # Never overwrite a real top-level column that happens to contain a dot.
                samples.setdefault(dotted_name, nested_value)
        return samples

    @staticmethod
    def _flatten_struct(prefix: str, value) -> dict:
        """Map every (recursively) nested field of a struct ``Row`` to ``prefix.field``."""
        from pyspark.sql import Row  # local import: only needed to detect struct values

        flattened = {}
        if isinstance(value, Row):
            for key, child in value.asDict().items():
                dotted = f'{prefix}.{key}'
                flattened[dotted] = child
                flattened.update(SparkParquetDataset._flatten_struct(dotted, child))
        return flattened
Functions
def spark_session() ‑> pyspark.sql.session.SparkSession-
Expand source code
def spark_session() -> SparkSession: return ( SparkSession.builder .appName('testing') .master('local[2]') .getOrCreate() )
Classes
class SparkParquetDataset (name: str, folder: Union[str, pathlib.Path, ForwardRef(None)] = None, description: Optional[str] = None, rich_description_path: Optional[str] = None, partition_cols: Optional[Sequence[str]] = None, columns: Optional[Sequence[hela._column_classes._ColumnType]] = None)-
Abstract Dataset class to be used when building your own datasets.
If you choose to build data interactivity through the data catalog, your own dataset classes are where you implement authentication and connection logic.
To use all available catalog features, implement the functions
BaseDataset.get_samples and BaseDataset.get_dates.
Attributes
name - The name of the dataset.
data_type - The data type of the dataset, e.g. "parquet" or "bigquery".
description - A description of the dataset as a string.
partition_cols - A list of column names (strings) used for partitioning.
rich_description_path - A path to a Markdown file that allows a longer, more detailed description. Primarily used for the generated catalog web page.
columns - A list of ColumnType objects defining the columns of the dataset.
path - The path to the dataset (the combination of folder and name).
Expand source code
class SparkParquetDataset(BaseDataset): def __init__( self, name: str, folder: Optional[Union[str, Path]] = None, description: Optional[str] = None, rich_description_path: Optional[str] = None, partition_cols: Optional[Sequence[str]] = None, columns: Optional[Sequence[_ColumnType]] = None ) -> None: super().__init__( name=name, data_type='parquet', folder=folder, description=description, rich_description_path=rich_description_path, columns=columns, partition_cols=partition_cols ) def spark(self) -> SparkSession: return spark_session() def write(self, df: DataFrame) -> int: df.write.mode('overwrite').parquet(str(self.path), partitionBy=self.partition_cols) return df.count() def load(self): return self.spark().read.parquet(str(self.path)) def get_dates(self) -> Set[date]: return set([ date.fromisoformat(p.stem.split('=')[-1]) for p in self.path.iterdir() if len(p.stem.split('=')) > 1 ]) def get_samples(self) -> dict: df = self.load().limit(1000) return df.select([first(x, ignorenulls=True).alias(x) for x in df.columns]).first().asDict()Ancestors
- hela._base_dataset.BaseDataset
- abc.ABC
Methods
def get_dates(self) ‑> Set[datetime.date]-
Implement this function to enable date-inspection functionality such as
BaseDataset.show_dates. It should return a set of dates when called, or None if the dates could not be fetched for some reason.
Expand source code
def get_dates(self) -> Set[date]: return set([ date.fromisoformat(p.stem.split('=')[-1]) for p in self.path.iterdir() if len(p.stem.split('=')) > 1 ]) def get_samples(self) ‑> dict-
Implement this function to enable sample-inspection functionality used in, e.g.,
BaseDataset.show_columns. It should return a dictionary mapping column names (strings) to sample values:
>>> {'my_column': 123}
Nested columns should use dot-notation names:
>>> {'parent_column.my_column': 123}
It should return None if samples could not be fetched:
>>> None
Expand source code
def get_samples(self) -> dict: df = self.load().limit(1000) return df.select([first(x, ignorenulls=True).alias(x) for x in df.columns]).first().asDict() def load(self)-
Expand source code
def load(self): return self.spark().read.parquet(str(self.path)) def spark(self) ‑> pyspark.sql.session.SparkSession-
Expand source code
def spark(self) -> SparkSession: return spark_session() def write(self, df: pyspark.sql.dataframe.DataFrame) ‑> int-
Expand source code
def write(self, df: DataFrame) -> int: df.write.mode('overwrite').parquet(str(self.path), partitionBy=self.partition_cols) return df.count()