-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_infer.sh
79 lines (68 loc) · 2.34 KB
/
run_infer.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/bin/bash
#
# Runs the CycleSQL inference step for one dataset/model pair and then
# evaluates the resulting predictions.
#
# Usage:
#   run_infer.sh DATASET_NAME MODEL_NAME TEST_FILE RAW_BEAM_OUTPUT_FILE \
#                TABLES_FILE DB_DIR TS_DB_DIR
#
# Every positional argument must be non-empty. The first missing one is
# reported on stderr and the script exits with status 1.

# Print a usage error on stderr and abort the script with status 1.
usage_error() {
  printf '%s\n' "$1" >&2
  exit 1
}

[[ -n "${1:-}" ]] || usage_error "First argument is the dataset name."
DATASET_NAME="$1"
[[ -n "${2:-}" ]] || usage_error "Second argument is the NL2SQL model name."
MODEL_NAME="$2"
[[ -n "${3:-}" ]] || usage_error "Third argument is the test JSON file."
TEST_FILE="$3"
[[ -n "${4:-}" ]] || usage_error "Fourth argument is the raw beam output txt file"
RAW_BEAM_OUTPUT_FILE="$4"
[[ -n "${5:-}" ]] || usage_error "Fifth argument is the dataset table schema file."
TABLES_FILE="$5"
[[ -n "${6:-}" ]] || usage_error "Sixth argument is the directory of the databases of the dataset."
DB_DIR="$6"
# NOTE(review): the previous message labelled this argument "(optional)", yet
# the script has always aborted when it is missing. The requirement is kept
# and the message now matches it; confirm whether the evaluator can run
# without --ts_db before relaxing this check.
[[ -n "${7:-}" ]] || usage_error "Seventh argument is the directory of the test suite databases of the dataset."
TS_DB_DIR="$7"
# Absolute path of the directory that contains this script.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
echo "===================================================================================================================================="
echo "INFO ****** CycleSQL Inference Pipeline Start ******"
# Define some variables
# Results are grouped per dataset and per model. The path is relative to the
# current working directory, not to $DIR.
OUTPUT_DIR="outputs/$DATASET_NAME/$MODEL_NAME"
# `mkdir -p` succeeds when the directory already exists, so no separate
# existence check is needed. Quoting keeps names containing spaces or glob
# characters intact. Abort if the directory cannot be created, since every
# later step writes into it.
mkdir -p "$OUTPUT_DIR" || exit $?
# Checkpoint directory passed to the inference step as --nli_model_dir.
NLI_MODEL_DIR=saved_models/checkpoint-500
# Choose the beam size passed to the inference step (--beam_size) for the
# requested NL2SQL model. Models sharing a value are grouped in one arm.
case "$MODEL_NAME" in
  chatgpt | gpt-4 | chess)
    BEAM_SIZE=5
    ;;
  dailsql | smbop | picard | resdsql-3b)
    BEAM_SIZE=8
    ;;
  resdsql)
    BEAM_SIZE=6
    ;;
  *)
    # A bare `exit` would reuse the status of the preceding echo (0) and make
    # an invalid model name look like success, so fail explicitly.
    echo "unknown NL2SQL model!" >&2
    exit 1
    ;;
esac
# Run inference unless a predictions file is already present, in which case
# the existing file is reused and only a warning is printed.
OUTPUT_FILE="$OUTPUT_DIR/preds.txt"
if [[ ! -f "$OUTPUT_FILE" ]]; then
  # All expansions are quoted so paths containing spaces or glob characters
  # reach the Python module as single arguments. A failing run aborts the
  # script with the module's exit status.
  python -m scripts.run_infer \
    --model_name "$MODEL_NAME" \
    --beam_size "$BEAM_SIZE" \
    --test_file "$TEST_FILE" \
    --beam_output_file "$RAW_BEAM_OUTPUT_FILE" \
    --nli_model_dir "$NLI_MODEL_DIR" \
    --table_file_path "$TABLES_FILE" \
    --db_dir "$DB_DIR" \
    --output_file_path "$OUTPUT_FILE" || exit $?
else
  echo "WARNING \`$OUTPUT_FILE\` already exists."
fi
# Final Evaluation
# The evaluator's stdout is captured in eval_result.txt inside the output
# directory.
EVALUATE_OUTPUT_FILE="$OUTPUT_DIR/eval_result.txt"
# Abort with the evaluator's exit status on failure, mirroring the inference
# step, so a failed evaluation is never reported as complete.
python -m spider_utils.eval.evaluation --gold "$TEST_FILE" --pred "$OUTPUT_FILE" \
  --db "$DB_DIR" --ts_db "$TS_DB_DIR" > "$EVALUATE_OUTPUT_FILE" || exit $?
echo "Spider evaluation complete! Results are saved in \`$EVALUATE_OUTPUT_FILE\`"
echo "===================================================================================================================================="