Skip to content

Commit df7db86

Browse files
committed
feat: support TSV files
1 parent e8459fb commit df7db86

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

examples/main/main.cpp

+27-23
Original file line numberDiff line numberDiff line change
@@ -851,41 +851,41 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
851851
return true;
852852
}
853853

854-
std::vector<std::vector<std::string>> read_csv(const std::string &csv_file) {
855-
// Read a csv file and return a vector where each element is a vector of [wav_filename, transcript, id, profile_id]
856-
std::vector<std::vector<std::string>> csv_data;
857-
std::ifstream csv(csv_file);
858-
if (!csv.is_open()) {
859-
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, csv_file.c_str());
860-
return csv_data;
854+
std::vector<std::vector<std::string>> read_delimited_file(const std::string &file_path, char delimiter) {
855+
// Read a delimited file (CSV or TSV), skip its header line, and return one vector per data row laid out as [wav_filename, transcript, id, profile_id, "", language, region, evaluation_value]
856+
std::vector<std::vector<std::string>> file_data;
857+
std::ifstream file(file_path);
858+
if (!file.is_open()) {
859+
fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, file_path.c_str());
860+
return file_data;
861861
}
862862

863863
// Skip the header
864864
std::string header;
865-
std::getline(csv, header);
865+
std::getline(file, header);
866866

867867
// Read each line and extract values
868868
std::string line;
869-
while (std::getline(csv, line)) {
869+
while (std::getline(file, line)) {
870870
std::stringstream ss(line);
871871
std::string wav_filename, id, profile_id, duration, wav_filesize, transcript, language, region, evaluation_value;
872872

873-
// Assuming the columns are in order: wav_filename,id,profile_id,duration,wav_filesize,transcript,language,region,evaluation_value
874-
std::getline(ss, wav_filename, ',');
875-
std::getline(ss, id, ',');
876-
std::getline(ss, profile_id, ',');
877-
std::getline(ss, duration, ',');
878-
std::getline(ss, wav_filesize, ',');
879-
std::getline(ss, transcript, ',');
880-
std::getline(ss, language, ',');
881-
std::getline(ss, region, ',');
882-
std::getline(ss, evaluation_value, ',');
873+
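// Extract fields using the given delimiter, assuming the columns are in order: wav_filename, id, profile_id, duration, wav_filesize, transcript, language, region, evaluation_value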
874+
std::getline(ss, wav_filename, delimiter);
875+
std::getline(ss, id, delimiter);
876+
std::getline(ss, profile_id, delimiter);
877+
std::getline(ss, duration, delimiter);
878+
std::getline(ss, wav_filesize, delimiter);
879+
std::getline(ss, transcript, delimiter);
880+
std::getline(ss, language, delimiter);
881+
std::getline(ss, region, delimiter);
882+
std::getline(ss, evaluation_value, delimiter);
883883

884884
std::vector<std::string> row = {wav_filename, transcript, id, profile_id, "", language, region, evaluation_value};
885-
csv_data.push_back(row);
885+
file_data.push_back(row);
886886
}
887887

888-
return csv_data;
888+
return file_data;
889889
}
890890

891891
std::string remove_leading_trailing_whitespace(const std::string& input) {
@@ -1096,8 +1096,12 @@ int main(int argc, char ** argv) {
10961096
// make a dict for the scores
10971097
std::map<std::string, std::vector<std::vector<std::pair<std::string, float>>> > csv_scores;
10981098
if (params.csv_file != "") {
1099-
fprintf(stderr, "%s: csv file: %s\n", __func__, params.csv_file.c_str());
1100-
csv_data = read_csv(params.csv_file.c_str());
1099+
char delimiter = ',';
1100+
if (params.csv_file.substr(params.csv_file.find_last_of(".") + 1) == "tsv") {
1101+
delimiter = '\t';
1102+
}
1103+
fprintf(stderr, "%s: file: %s\n", __func__, params.csv_file.c_str());
1104+
csv_data = read_delimited_file(params.csv_file.c_str(), delimiter);
11011105
// make a dictionary of wav_filename -> transcript, id, profile_id, result, confidence_score
11021106
for (const auto &element : csv_data) {
11031107
// Make an empty vector to store the result and confidence score

0 commit comments

Comments
 (0)