-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: complete postgres client compatibility tests (#111)
- Loading branch information
1 parent
5ea43ee
commit afddd96
Showing
9 changed files
with
672 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/bash | ||
|
||
rm -rf ./c/pg_test \ | ||
./csharp/bin ./csharp/obj \ | ||
./go/pg \ | ||
./java/*.class \ | ||
./rust/target ./rust/Cargo.lock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Data; | ||
using Microsoft.Data.SqlClient; | ||
using System.IO; | ||
|
||
public class PGTest | ||
{ | ||
public class Tests | ||
{ | ||
private SqlConnection conn; | ||
private SqlCommand cmd; | ||
private List<Test> tests = new List<Test>(); | ||
|
||
public void Connect(string ip, int port, string user, string password) | ||
{ | ||
try | ||
{ | ||
string connectionString = $"Server={ip},{port};User Id={user};Password={password};"; | ||
conn = new SqlConnection(connectionString); | ||
conn.Open(); | ||
cmd = conn.CreateCommand(); | ||
cmd.CommandType = CommandType.Text; | ||
} | ||
catch (SqlException e) | ||
{ | ||
throw new Exception(e.Message); | ||
} | ||
} | ||
|
||
public void Disconnect() | ||
{ | ||
try | ||
{ | ||
cmd.Dispose(); | ||
conn.Close(); | ||
} | ||
catch (SqlException e) | ||
{ | ||
throw new Exception(e.Message); | ||
} | ||
} | ||
|
||
public void AddTest(string query, string[][] expectedResults) | ||
{ | ||
tests.Add(new Test(query, expectedResults)); | ||
} | ||
|
||
public bool RunTests() | ||
{ | ||
foreach (var test in tests) | ||
{ | ||
if (!test.Run(cmd)) | ||
{ | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
public void ReadTestsFromFile(string filename) | ||
{ | ||
try | ||
{ | ||
using (var reader = new StreamReader(filename)) | ||
{ | ||
string line; | ||
while ((line = reader.ReadLine()) != null) | ||
{ | ||
if (string.IsNullOrWhiteSpace(line)) continue; | ||
string query = line; | ||
var results = new List<string[]>(); | ||
while ((line = reader.ReadLine()) != null && !string.IsNullOrWhiteSpace(line)) | ||
{ | ||
results.Add(line.Split(',')); | ||
} | ||
string[][] expectedResults = results.ToArray(); | ||
AddTest(query, expectedResults); | ||
} | ||
} | ||
} | ||
catch (IOException e) | ||
{ | ||
Console.Error.WriteLine(e.Message); | ||
Environment.Exit(1); | ||
} | ||
} | ||
|
||
public class Test | ||
{ | ||
private string query; | ||
private string[][] expectedResults; | ||
|
||
public Test(string query, string[][] expectedResults) | ||
{ | ||
this.query = query; | ||
this.expectedResults = expectedResults; | ||
} | ||
|
||
public bool Run(SqlCommand cmd) | ||
{ | ||
try | ||
{ | ||
Console.WriteLine("Running test: " + query); | ||
cmd.CommandText = query; | ||
using (var reader = cmd.ExecuteReader()) | ||
{ | ||
if (!reader.HasRows) | ||
{ | ||
Console.WriteLine("Returns 0 rows"); | ||
return expectedResults.Length == 0; | ||
} | ||
if (reader.FieldCount != expectedResults[0].Length) | ||
{ | ||
Console.Error.WriteLine($"Expected {expectedResults[0].Length} columns, got {reader.FieldCount}"); | ||
return false; | ||
} | ||
int rows = 0; | ||
while (reader.Read()) | ||
{ | ||
for (int col = 0; col < expectedResults[rows].Length; col++) | ||
{ | ||
string result = reader.GetString(col); | ||
if (expectedResults[rows][col] != result) | ||
{ | ||
Console.Error.WriteLine($"Expected:\n'{expectedResults[rows][col]}'"); | ||
Console.Error.WriteLine($"Result:\n'{result}'\nRest of the results:"); | ||
while (reader.Read()) | ||
{ | ||
Console.Error.WriteLine(reader.GetString(0)); | ||
} | ||
return false; | ||
} | ||
} | ||
rows++; | ||
} | ||
Console.WriteLine("Returns " + rows + " rows"); | ||
if (rows != expectedResults.Length) | ||
{ | ||
Console.Error.WriteLine($"Expected {expectedResults.Length} rows"); | ||
return false; | ||
} | ||
return true; | ||
} | ||
} | ||
catch (SqlException e) | ||
{ | ||
Console.Error.WriteLine(e.Message); | ||
return false; | ||
} | ||
} | ||
} | ||
} | ||
|
||
public static void Main(string[] args) | ||
{ | ||
if (args.Length < 5) | ||
{ | ||
Console.Error.WriteLine("Usage: PGTest <ip> <port> <user> <password> <testFile>"); | ||
Environment.Exit(1); | ||
} | ||
|
||
var tests = new Tests(); | ||
tests.Connect(args[0], int.Parse(args[1]), args[2], args[3]); | ||
tests.ReadTestsFromFile(args[4]); | ||
|
||
if (!tests.RunTests()) | ||
{ | ||
tests.Disconnect(); | ||
Environment.Exit(1); | ||
} | ||
tests.Disconnect(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<OutputType>Exe</OutputType> | ||
<TargetFramework>net8.0</TargetFramework> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.Data.SqlClient" Version="5.2" /> | ||
</ItemGroup> | ||
|
||
</Project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
package main | ||
|
||
import ( | ||
"bufio" | ||
"database/sql" | ||
"fmt" | ||
"os" | ||
"strconv" | ||
"strings" | ||
|
||
_ "github.com/lib/pq" | ||
) | ||
|
||
type Test struct { | ||
query string | ||
expectedResults [][]string | ||
} | ||
|
||
type Tests struct { | ||
conn *sql.DB | ||
tests []Test | ||
} | ||
|
||
func (t *Tests) connect(ip string, port int, user, password string) { | ||
connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=postgres sslmode=disable", ip, port, user, password) | ||
var err error | ||
t.conn, err = sql.Open("postgres", connStr) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
func (t *Tests) disconnect() { | ||
err := t.conn.Close() | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
func (t *Tests) addTest(query string, expectedResults [][]string) { | ||
t.tests = append(t.tests, Test{query, expectedResults}) | ||
} | ||
|
||
func (t *Tests) readTestsFromFile(filename string) { | ||
file, err := os.Open(filename) | ||
if err != nil { | ||
panic(err) | ||
} | ||
defer func(file *os.File) { | ||
err := file.Close() | ||
if err != nil { | ||
panic(err) | ||
} | ||
}(file) | ||
|
||
scanner := bufio.NewScanner(file) | ||
var query string | ||
var results [][]string | ||
for scanner.Scan() { | ||
line := strings.TrimSpace(scanner.Text()) | ||
if line == "" { | ||
if query != "" { | ||
t.addTest(query, results) | ||
query = "" | ||
results = nil | ||
} | ||
} else if query == "" { | ||
query = line | ||
} else { | ||
results = append(results, strings.Split(line, ",")) | ||
} | ||
} | ||
if query != "" { | ||
t.addTest(query, results) | ||
} | ||
if err := scanner.Err(); err != nil { | ||
panic(err) | ||
} | ||
} | ||
|
||
func (t *Tests) runTests() bool { | ||
for _, test := range t.tests { | ||
if !t.runTest(test) { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func (t *Tests) runTest(test Test) bool { | ||
fmt.Println("Running test:", test.query) | ||
rows, err := t.conn.Query(test.query) | ||
if err != nil { | ||
fmt.Println("Error executing query:", err) | ||
return false | ||
} | ||
defer func(rows *sql.Rows) { | ||
err := rows.Close() | ||
if err != nil { | ||
panic(err) | ||
} | ||
}(rows) | ||
|
||
columns, err := rows.Columns() | ||
if err != nil { | ||
fmt.Println("Error getting columns:", err) | ||
return false | ||
} | ||
|
||
if len(test.expectedResults) == 0 { | ||
fmt.Println("Returns 0 rows") | ||
return len(columns) == 0 | ||
} | ||
|
||
if len(columns) != len(test.expectedResults[0]) { | ||
fmt.Printf("Expected %d columns, got %d\n", len(test.expectedResults[0]), len(columns)) | ||
return false | ||
} | ||
|
||
var rowCount int | ||
for rows.Next() { | ||
row := make([]string, len(columns)) | ||
rowPointers := make([]interface{}, len(columns)) | ||
for i := range row { | ||
rowPointers[i] = &row[i] | ||
} | ||
if err := rows.Scan(rowPointers...); err != nil { | ||
fmt.Println("Error scanning row:", err) | ||
return false | ||
} | ||
|
||
for i, expected := range test.expectedResults[rowCount] { | ||
if row[i] != expected { | ||
fmt.Printf("Expected: '%s', got: '%s'\n", expected, row[i]) | ||
return false | ||
} | ||
} | ||
rowCount++ | ||
} | ||
|
||
if rowCount != len(test.expectedResults) { | ||
fmt.Printf("Expected %d rows, got %d\n", len(test.expectedResults), rowCount) | ||
return false | ||
} | ||
|
||
fmt.Printf("Returns %d rows\n", rowCount) | ||
return true | ||
} | ||
|
||
func main() { | ||
if len(os.Args) < 6 { | ||
fmt.Println("Usage: pg_test <ip> <port> <user> <password> <testFile>") | ||
os.Exit(1) | ||
} | ||
|
||
ip := os.Args[1] | ||
port, ok := strconv.Atoi(os.Args[2]) | ||
if ok != nil { | ||
fmt.Println("Invalid port:", os.Args[2]) | ||
os.Exit(1) | ||
} | ||
user := os.Args[3] | ||
password := os.Args[4] | ||
testFile := os.Args[5] | ||
|
||
tests := &Tests{} | ||
tests.connect(ip, port, user, password) | ||
tests.readTestsFromFile(testFile) | ||
|
||
if !tests.runTests() { | ||
tests.disconnect() | ||
os.Exit(1) | ||
} | ||
tests.disconnect() | ||
} |
Oops, something went wrong.