Skip to content

Commit

Permalink
fix: fix split_data in csvts connector
Browse files Browse the repository at this point in the history
chore: in-code function documentation for csv and csvts connectors
  • Loading branch information
beniz authored and sileht committed Oct 9, 2020
1 parent f46a710 commit 8f554b5
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 35 deletions.
5 changes: 1 addition & 4 deletions src/csvinputfileconn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ namespace dd
{
// not a number, skip for now
if (column_id == col) // if id is string, replace with number /
// TODO: better scheme
vals.push_back(c);
else
{
Expand Down Expand Up @@ -351,7 +350,7 @@ namespace dd
}

void CSVInputFileConn::read_csv(const std::string &fname,
const bool forbid_shuffle)
const bool &forbid_shuffle)
{
std::ifstream csv_file(fname, std::ios::binary);
_logger->info("fname={} / open={}", fname, csv_file.is_open());
Expand Down Expand Up @@ -434,10 +433,8 @@ namespace dd
}
if (!_id.empty())
add_test_csvline(cid, vals);
//_csvdata_test.emplace_back(cid,vals);
else
add_test_csvline(std::to_string(nlines), vals);
//_csvdata_test.emplace_back(std::to_string(nlines),vals);

// debug
/*std::cout << "csv test data line=";
Expand Down
170 changes: 155 additions & 15 deletions src/csvinputfileconn.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace dd
{
class CSVInputFileConn;

/**
* \brief fetched data element for CSV inputs
*/
class DDCsv
{
public:
Expand All @@ -57,6 +60,9 @@ namespace dd
std::shared_ptr<spdlog::logger> _logger;
};

/**
* \brief In-memory CSV data line holder
*/
class CSVline
{
public:
Expand All @@ -71,6 +77,11 @@ namespace dd
std::vector<double> _v; /**< csv line data */
};

/**
* \brief Categorical values mapper.
* Categorical values are discrete sets that are converted to int
* This class builds and holds the mapper from value to int
*/
class CCategorical
{
public:
Expand All @@ -81,18 +92,32 @@ namespace dd
{
}

/**
* \brief adds a categorical value and its position
* @param v is the categorical value
* @param val is the categorical value position in the discrete set
*/
void add_cat(const std::string &v, const int &val)
{
std::unordered_map<std::string, int>::iterator hit;
if ((hit = _vals.find(v)) == _vals.end())
_vals.insert(std::pair<std::string, int>(v, val));
}

/**
* \brief adds categorical value at the last position in the discrete set
* @param v is the categorical value
*/
void add_cat(const std::string &v)
{
add_cat(v, _vals.size());
}

/**
* \brief gets the discrete value for a categorical value
* @param v is the categorical value
* @return the discrete position value
*/
int get_cat_num(const std::string &v) const
{
std::unordered_map<std::string, int>::const_iterator hit;
Expand All @@ -105,6 +130,9 @@ namespace dd
_vals; /**< categorical value mapping. */
};

/**
* \brief Generic CSV data input connector
*/
class CSVInputFileConn : public InputConnectorStrategy
{
public:
Expand Down Expand Up @@ -200,7 +228,6 @@ namespace dd
_label_set.insert(std::pair<std::string, int>(_label.at(l), l));
}
}
// TODO: array
if (ad_input.has("label_offset"))
{
try
Expand Down Expand Up @@ -237,6 +264,11 @@ namespace dd
this->set_timeout(ad_input);
}

/**
* \brief reads a categorical value mapping from inputs
* this most often applies when the mapping is provided at inference
* time.
*/
void read_categoricals(const APIData &ad_input)
{
if (ad_input.has("categoricals_mapping"))
Expand All @@ -258,6 +290,10 @@ namespace dd
}
}

/**
* \brief scales a vector of double based on min/max bounds
* @param vals the vector with values to be scaled
*/
void scale_vals(std::vector<double> &vals)
{
auto lit = _columns.begin();
Expand Down Expand Up @@ -297,6 +333,11 @@ namespace dd
}
}

/**
* \brief read min/max bounds for scaling input data
* sets _scale flag and _min_vals, _max_vals vectors
* @param ad_input the APIData input object
*/
void read_scale_vals(const APIData &ad_input)
{
if (ad_input.has("scale") && ad_input.get("scale").get<bool>())
Expand Down Expand Up @@ -345,12 +386,23 @@ namespace dd
}
}

/**
* \brief shuffle CSV data vector if shuffle flag is true
* @param csvdata CSV data line vector to be shuffled
*/
void shuffle_data(std::vector<CSVline> &csvdata)
{
if (_shuffle)
std::shuffle(csvdata.begin(), csvdata.end(), _g);
}

/**
* \brief uses _test_split value to split the input dataset
* @param csvdata is the full CSV dataset holder, in output reduced to size
* 1-_test_split
* @param csvdata_test is the test dataset sink, in otput of size
* _test_split
*/
void split_data(std::vector<CSVline> &csvdata,
std::vector<CSVline> &csvdata_test)
{
Expand All @@ -376,18 +428,32 @@ namespace dd
}
}

/**
* \brief adds a CSV data value line to the training set
* @param id
* @param vals
*/
virtual void add_train_csvline(const std::string &id,
std::vector<double> &vals)
{
_csvdata.emplace_back(id, std::move(vals));
}

/**
* \brief adds a CSV data value line to the test set
* @param id
* @param vals
*/
virtual void add_test_csvline(const std::string &id,
std::vector<double> &vals)
{
_csvdata_test.emplace_back(id, std::move(vals));
}

/**
* \brief input data transforms
* @param ad APIData input object
*/
void transform(const APIData &ad)
{
get_data(ad);
Expand Down Expand Up @@ -480,15 +546,39 @@ namespace dd
throw InputConnectorBadParamException("no data could be found");
}

/**
* \brief parse CSV header and sets the reference CSV columns
* @param hline header line as string
*/
void read_header(std::string &hline);

/**
* \brief reads a full CSV dataset and builds the categorical variables and
* values mapper
* @param csv_file input stream for the CSV data file
*/
void fillup_categoricals(std::ifstream &csv_file);

/**
* \brief reads a CSV data line, fills up values and categorical variables
* as one-hot-vectors
* @param hline CSV data line
* @param delim CSV column delimiter
* @param vals vector to be filled up with CSV data values
* @param column_id stores the column that holds the line id
* @param nlines current line counter
*/
void read_csv_line(const std::string &hline, const std::string &delim,
std::vector<double> &vals, std::string &column_id,
int &nlines);

void read_csv(const std::string &fname, const bool forbid_shuffle = false);
/**
* \brief reads a full CSV data file, calls read_csv_line
* @param fname the CSV file name
* @param forbid_shuffle whether shuffle is forbidden
*/
void read_csv(const std::string &fname,
const bool &forbid_shuffle = false);

int batch_size() const
{
Expand All @@ -508,6 +598,10 @@ namespace dd
return _columns.size() - _label.size(); // minus label
}

/**
* \brief fills out response params from input connector values
* @param out APIData that holds the output values
*/
void response_params(APIData &out)
{
APIData adparams;
Expand Down Expand Up @@ -552,6 +646,11 @@ namespace dd
out.add("parameters", adparams);
}

/**
* \brief tests whether a CSV column holds a categorical variable
* @param c the CSV column
* @return true if category, false otherwise
*/
bool is_category(const std::string &c)
{
std::unordered_map<std::string, CCategorical>::const_iterator hit;
Expand All @@ -560,11 +659,25 @@ namespace dd
return false;
}

/**
* \brief adds a value to a categorical variable mapping, modifies
* _categoricals
* @param c the variable name (column)
* @param v the new categorical value to be added
*/
void update_category(const std::string &c, const std::string &val);

/**
* \brief update data columns with one-hot columns introduced to translate
* categorical variables
*/
void update_columns();

// below helpers for csvts
/**
* \brief returns min/max variable values across a CSV dataset
* @param fname CSV filename
* @return pair of vectors for min/max values
*/
std::pair<std::vector<double>, std::vector<double>>
get_min_max_vals(std::string &fname)
{
Expand All @@ -573,6 +686,10 @@ namespace dd
return get_min_max_vals();
}

/**
* \brief finds min/max variable values across a CSV dataset
* @param fname CSV filename
*/
void find_min_max(std::string &fname)
{
std::ifstream csv_file(fname, std::ios::binary);
Expand All @@ -588,18 +705,38 @@ namespace dd
else
find_min_max(csv_file);
}

/**
* \brief finds min/max variable values across a CSV dataset
* @param csv_file CSV file stream
*/
void find_min_max(std::ifstream &csv_file);

/**
* \brief removes min/max values for the CSV dataset variables
*/
void clear_min_max()
{
_min_vals.clear();
_max_vals.clear();
}

/**
* \brief get pre-obtained min/max variable values
* @return pair of vector of min/max variable values
*/
std::pair<std::vector<double>, std::vector<double>> get_min_max_vals()
{
return std::pair<std::vector<double>, std::vector<double>>(_min_vals,
_max_vals);
}

/**
* \brief returns a one-hot-vector of a given size and index
* @param cnum the index of the positive one-hot
* @param size the size of the vector
* @return the one hot vector of double
*/
std::vector<double> one_hot_vector(const int &cnum, const int &size)
{
std::vector<double> v(size, 0.0);
Expand All @@ -610,32 +747,35 @@ namespace dd
// options
bool _shuffle = false;
std::mt19937 _g;
std::string _csv_fname;
std::string _csv_test_fname;
std::list<std::string> _columns;
std::vector<std::string> _label;
std::string _csv_fname; /**< csv main filename. */
std::string _csv_test_fname; /**< csv test filename (optional). */
std::list<std::string> _columns; /**< list of csv columns. */
std::vector<std::string> _label; /**< list of label columns. */
std::unordered_map<std::string, int> _label_set;
std::string _delim = ",";
int _id_pos = -1;
std::vector<int> _label_pos;
std::vector<int> _label_pos; /**< column positions of the labels. */
std::vector<int> _label_offset; /**< negative offset so that labels range
from 0 onward */
std::unordered_set<std::string> _ignored_columns;
std::unordered_set<int> _ignored_columns_pos;
std::unordered_set<std::string>
_ignored_columns; /**< set of ignored columns. */
std::unordered_set<int>
_ignored_columns_pos; /**< ignored columns indexes. */
std::string _id;
bool _scale = false; /**< whether to scale all data between 0 and 1 */
bool _dont_scale_labels
= true; // original csv input conn do not scale labels, while it is
= true; // original csv input conn does not scale labels, while it is
// needed for csv timeseries
bool _scale_between_minus1_and_1 = false;
bool _scale_between_minus1_and_1
= false; /**< whether to scale within [-1,1]. */
std::vector<double>
_min_vals; /**< upper bound used for auto-scaling data */
std::vector<double>
_max_vals; /**< lower bound used for auto-scaling data */
std::unordered_map<std::string, CCategorical>
_categoricals; /**< auto-converted categorical variables */
double _test_split = -1;
int _detect_cols = -1;
_categoricals; /**< auto-converted categorical variables */
double _test_split = -1; /**< dataset test split ratio (optional). */
int _detect_cols = -1; /**< number of detected csv columns. */

// data
std::vector<CSVline> _csvdata;
Expand Down
Loading

0 comments on commit 8f554b5

Please sign in to comment.