# -*- coding: utf-8 -*-
"""
Datasets are a class used to hold imported data and newly generated data.
All transformations require a dataset, which will be manipulated by
the library and eventually exported for use.
"""
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
import os
import random
import re
from json import load
from csv import DictReader
import dataset as dataset_db
from datafuzz.settings import HAS_PANDAS, HAS_NUMPY
from datafuzz.output.helpers import obj_to_output
if HAS_PANDAS:
import pandas as pd
if HAS_NUMPY:
import numpy as np
class DataSet(object):
    """
    DataSet objects are used as the primary datatype for
    passing around data in datafuzz.

    If pandas is installed, it will use dataframes to load
    and transform data; otherwise, it will use a list.
    You can also specify to not use pandas by passing
    keyword argument `pandas=False`.

    Supported inputs are JSON and CSV files, numpy 2D arrays,
    sql queries (you must pass a `db_uri` keyword argument and
    a `query` argument), pandas DataFrames and Python lists
    (of dictionaries or lists).

    Attributes:
        DATA_TYPES (list): possible datatypes (pandas, numpy, list).
        FILE_REGEX (str): regex used to extract the file name from a
            `file://` input or output string.
        USE_PANDAS (bool): whether pandas is installed and OK to use
            (i.e. no `pandas=False` keyword was passed).
        USE_NUMPY (bool): whether numpy is installed.
        records (obj): data records for the dataset (DataFrame,
            ndarray or list, depending on `data_type`).
        input (obj): initial input for dataset
            (can be dataframe, list, numpy array, filename or `sql`)
        output (obj): output target
            (if specified, can be dataframe, list,
            numpy array, filename or `sql`)
        original (obj): copy of input which won't be modified
        data_type (str): dataset datatype (pandas, numpy, list).
        db_uri (str): dataset database connection string
            (required only if using `sql` as input or output)
        query (str): dataset database select query string
            (required only if using `sql` as input)
        table (str): dataset database output table
            (required only if using `sql` as output)
    """
    USE_PANDAS = HAS_PANDAS
    USE_NUMPY = HAS_NUMPY
    DATA_TYPES = ['pandas', 'numpy', 'list']
    FILE_REGEX = r'file://(?P<filename>.*)'
def __init__(self, input_obj, **kwargs):
self.records = []
self.input = input_obj
self.original = input_obj
self.data_type = None
self.output = kwargs.get('output')
self.db_uri = None
self.table = None
self.query = None
self.index = -1
validate_db = False
if kwargs.get('pandas') is False:
self.USE_PANDAS = False
if isinstance(self.input, str) and self.input == 'sql':
self.db_uri = kwargs.get('db_uri')
self.query = kwargs.get('query')
validate_db = True
if isinstance(self.output, str) and self.output == 'sql':
if not self.db_uri:
self.db_uri = kwargs.get('db_uri')
self.table = kwargs.get('table')
validate_db = True
if validate_db:
self.validate_db()
self._parse_input()
self.validate_parsed()
def __len__(self):
""" Return length of self.records """
if self.data_type in ['pandas', 'numpy']:
return self.records.shape[0]
return len(self.records)
def __iter__(self):
""" Iterator object of self.records """
return self
def __next__(self):
""" Iterator object of self.records This uses `self.index`.
NOTE: index is only reset on init.
TODO:
- find best way to reset index to -1 on next loop
"""
self.index = self.index + 1
if self.index >= len(self):
raise StopIteration
if self.data_type == 'pandas':
return self.records.iloc[self.index, :]
elif self.data_type == 'numpy':
return self.records[self.index, :]
return self.records[self.index]
def __getitem__(self, idx):
""" Return rows from self.records based on index """
if self.data_type == 'pandas':
return self.records.iloc[idx, :]
elif self.data_type == 'numpy':
return self.records[idx, :]
return self.records[idx]
def _parse_input(self):
""" initialization method which will call read input
for the passed input type (in init). This will use \
the _read_$type methods to then generate the \
`self.input`, `self.data_type`, `self.records` \
and `self.original` attributes.
NOTE: Files are parsed using regex and searching \
for file://$file_name
"""
if self.USE_PANDAS and isinstance(self.input, pd.DataFrame):
self._read_pandas()
elif self.USE_NUMPY and isinstance(self.input, np.ndarray):
self._read_numpy()
elif isinstance(self.input, str):
if self.input.startswith('file:'):
if self.input.endswith('.csv'):
self._read_csv()
elif self.input.endswith('.json'):
self._read_json()
elif self.input == 'sql':
self._read_sql()
elif isinstance(self.input, list):
self._read_list()
[docs] def validate_parsed(self):
""" Validate if data was properly parsed. This tests:
- valid data types
- records properly parsed and set to self.records
- self.original exists
It will raise an exception if the validation fails.
"""
try:
assert self.original is not None
assert self.data_type in self.DATA_TYPES
assert not isinstance(self.records, str)
assert isinstance(self.records, Iterable)
except AssertionError:
raise TypeError('Unsupported input: {}'.format(self.input))
try:
assert len(self.records) > 0
except AssertionError:
raise Exception('Could not parse data for: {}'.format(self.input))
[docs] def validate_db(self):
""" Validate that proper variables are set and a connection \
can be established with the database if either \
input or output are set to `sql`.
This will raise an exception if validation fails.
"""
try:
assert self.db_uri is not None
assert self.query is not None or self.table is not None
assert dataset_db.connect(self.db_uri)
except AssertionError:
raise Exception(
'You must define a valid db_uri and ' +
'query or table to use SQL.')
def _read_pandas(self):
""" Read in pandas dataframe"""
self.original = self.input
self.data_type = 'pandas'
self.records = self.original.copy()
def _read_numpy(self):
""" Read in numpy array"""
self.original = self.input
self.data_type = 'numpy'
self.records = self.input.copy()
def _read_csv(self):
""" Read in csv to list or dataframe"""
self.original = self.input
if self.USE_PANDAS:
with open(self.input_filename, 'r') as myf:
self.input = pd.read_csv(myf)
self.data_type = 'pandas'
else:
with open(self.input_filename, 'r') as myf:
self.input = list(DictReader(myf))
self.data_type = 'list'
self.records = self.input.copy()
def _read_sql(self):
""" Read in sql to list or dataframe"""
self.original = self.input
if self.USE_PANDAS:
self.input = pd.read_sql_query(self.query, self.db_uri)
self.data_type = 'pandas'
else:
with dataset_db.connect(self.db_uri) as db:
self.input = list(db.query(self.query))
self.data_type = 'list'
self.records = self.input.copy()
def _read_json(self):
""" Read in json to list or dataframe"""
self.original = self.input
if self.USE_PANDAS:
with open(self.input_filename, 'r') as myf:
self.input = pd.read_json(myf)
self.data_type = 'pandas'
else:
with open(self.input_filename, 'r') as myf:
self.input = load(myf)
self.data_type = 'list'
if not isinstance(self.input, list):
raise Exception(
'The JSON file must contain a list for datafuzz use.')
self.records = self.input.copy()
def _read_list(self):
""" Read in list to list or dataframe"""
self.original = self.input
if self.USE_PANDAS:
self.input = pd.DataFrame(self.input)
self.data_type = 'pandas'
else:
self.data_type = 'list'
self.records = self.input.copy()
@property
def input_filename(self):
""" Return filename if input follows proper file format \
file://[absolute or relative filepath]
NOTE: this will raise an exception if the file is not found
"""
filename = re.match(self.FILE_REGEX, self.original).group('filename')
if not os.path.exists(filename):
raise Exception('Could not retrieve filename {}'.format(filename))
return filename
@property
def output_filename(self):
""" Return filename if output follows proper file format \
file://[absolute or relative filepath]
"""
return re.match(self.FILE_REGEX, self.output).group('filename')
[docs] def sample(self, percentage, columns=False):
""" Get a sample from the dataset.
Arguments:
percentage (float): percentage of dataset to sample \
should be a value from 0.0-1.0
Kwargs:
columns (bool): option to sample columns from dataset \
default is False
Returns:
A sample from the dataset with matching datatype
"""
if self.data_type == 'pandas':
if columns:
sample = np.random.choice(
self.records.columns,
round(self.records.shape[1] * percentage), replace=False)
else:
sample = self.records.sample(frac=percentage).copy(deep=True)
elif self.data_type == 'numpy':
if columns:
sample = np.random.choice(
range(len(self.records[0])),
round(len(self.records[0]) * percentage), replace=False)
else:
sample = self.records[
np.random.choice(self.records.shape[0],
round(self.records.shape[0] * percentage),
replace=False)]
else:
if columns:
sample = random.sample(
range(len(self.records[0])),
round(len(self.records[0]) * percentage))
else:
sample = random.sample(self.records,
round(len(self.records) * percentage))
return sample
[docs] def append(self, rows):
""" Append rows to DataSet records
Arguments:
rows (list): rows to add or concatenate
TODO:
- is a shuffle needed?
- should the index be maintained or reordered
- should new indexes be ordered or not
"""
if self.data_type == 'list':
self.records.extend(rows)
elif self.data_type == 'numpy':
self.records = np.append(self.records, rows, axis=0)
else:
if not isinstance(rows, pd.DataFrame):
rows = pd.DataFrame(rows)
self.records = pd.concat([self.records, rows], ignore_index=True)
[docs] def to_output(self):
""" Transform DataSet records to output. \
This uses helper method `obj_to_output` \
located in `output/helpers.py`
Returns output object or filepath.
"""
return obj_to_output(self)
[docs] def column_idx(self, column):
""" Return numeric index of a column
NOTE: if column is not found, raises an AttributeError
"""
if isinstance(column, str) and column.isnumeric():
return int(column)
elif self.data_type == 'pandas':
return self.records.columns.get_loc(column)
elif self.data_type == 'list' and isinstance(self.records[0], dict):
return list(self.records[0].keys()).index(column)
elif isinstance(column, int):
return column
elif not isinstance(column, str):
if HAS_NUMPY and np.issubdtype(column, np.integer):
return column
raise AttributeError('Column {} could not be found!'.format(column))
[docs] def column_dtype(self, column):
""" Return dtype of column
Arguments:
column (int): column index
Return:
data type of the column
TODO:
- determine smart way to test more than one row for a list?
"""
if self.data_type == 'pandas':
return self.records.iloc[:, column].dtype
elif self.data_type == 'numpy':
return self.records[:, column].dtype
elif isinstance(self.records[0], dict):
return type(list(self.records[0].values())[column])
return type(self.records[0][column])
[docs] def column_agg(self, column, agg_func):
""" Perform aggregate function on given column
Arguments:
column (int): column index
agg_func (function): aggregate function to perform on column
Returns aggregate result
Example:
`dataset.column_agg(3, min)`
"""
if self.data_type == 'pandas':
return agg_func(self.records.iloc[:, column])
elif self.data_type == 'numpy':
return agg_func(self.records[:, column])
elif isinstance(self.records[0], dict):
return agg_func([list(x.values())[column] for x in self.records])
return agg_func([x[column] for x in self.records])