From 75c61a8095af75cdeca08ac2d6c4374b45be6ede Mon Sep 17 00:00:00 2001 From: Ricardo Yaben <32867697+RicYaben@users.noreply.github.com> Date: Sun, 10 Jul 2022 11:58:14 +0200 Subject: [PATCH] Proxy for TCP and UDP [#63 #65 #66] --- internal/proxy/middlewares.go | 64 ++++++++++ internal/proxy/proxy.go | 224 ++++++++++++++++++++++++++++++++++ internal/proxy/tcpproxy.go | 153 +++++++++++++++++++++++ internal/proxy/udpproxy.go | 129 ++++++++++++++++++++ 4 files changed, 570 insertions(+) create mode 100644 internal/proxy/middlewares.go create mode 100644 internal/proxy/proxy.go create mode 100644 internal/proxy/tcpproxy.go create mode 100644 internal/proxy/udpproxy.go diff --git a/internal/proxy/middlewares.go b/internal/proxy/middlewares.go new file mode 100644 index 0000000..c1f13a9 --- /dev/null +++ b/internal/proxy/middlewares.go @@ -0,0 +1,64 @@ +package proxy + +import ( + "fmt" + "net" +) + +var ( + // Exportable middlewares manager + Middlewares = NewMiddlewareManager() +) + +// Use this interface to create new middlewares that include the `handle` function +type Middleware interface { + // Handle a connection, do something to it + handle(conn net.Conn) (net.Conn, error) +} + +type MiddlewareManager interface { + // Apply all the registered middlewares having the connection in consideration + Apply(conn net.Conn) (ret net.Conn, err error) + // Register a new middleware + Register(middleware Middleware) (Middleware, error) +} + +type MiddlewareManagerItem struct { + MiddlewareManager + + // List of middlewares + middlewares []Middleware +} + +// Register a middleware +func (mm *MiddlewareManagerItem) Register(middleware Middleware) (mid Middleware, err error) { + // Iterate the registered middlewares + for _, md := range mm.middlewares { + if md == middleware { + err = fmt.Errorf("middleware already registered") + return + } + } + + // Append the middleware to the list of registered middlewares + mm.middlewares = append(mm.middlewares, middleware) + mid = middleware + + return +} + +// Apply each middleware to the connection +func (mm *MiddlewareManagerItem) Apply(conn net.Conn) (ret net.Conn, err error) { + for _, middleware := range mm.middlewares { + ret, err = middleware.handle(conn) + } + + return +} + +func NewMiddlewareManager() *MiddlewareManagerItem { + return &MiddlewareManagerItem{ + // Create a slice of size 0 for the middlewares + middlewares: make([]Middleware, 0), + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..8191337 --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,224 @@ +package proxy + +import ( + "fmt" + "net" + "sync" + + "github.com/riotpot/pkg/services" +) + +const ( + TCP = "tcp" + UDP = "udp" +) + +// Proxy interface. +type Proxy interface { + // Start proxy function + Start() + + // Stop the proxy + Stop() error + // Check if the proxy is running + Alive() bool + + // Setter and Getter for the port + SetPort(port int) (int, error) + GetPort() int + + // Set the service in the proxy + SetService(service services.Service) +} + +// Abstraction of the proxy endpoint +// Contains private fields, do not use outside of this package +type AbstractProxy struct { + Proxy + + // Port in where the proxy will listen + port int + // Protocol meant for this proxy + protocol string + + // Create a channel to stop the proxy gracefully + // This channel is also used to guess if the proxy is running + stop chan struct{} + + // Pointer to the slice of middlewares for the proxies + // All the proxies should apply and share the same middlewares + // Perhaps this can be changed in the future given the need to apply middlewares per proxy + middlewares *MiddlewareManagerItem + + // Service to proxy + service services.Service + + // Waiting group for the server + wg sync.WaitGroup +} + +// Simple function to check if the proxy is running +func (pe *AbstractProxy) Alive() (alive bool) { + // When the proxy is instantiated, the stop channel is nil; + // therefore, the proxy is not running + if pe.stop == nil { + return + } + + // [7/4/2022] NOTE: The logic of this block is difficult to read. + // However, the select block will only give the default value when there is nothing + // to read from the channel while the channel is still open. + // When the channel is closed, the first case is not blocked, so we can not + // read "anything else" from the channel + select { + // Return if the channel is closed + case <-pe.stop: + // Return if the channel is open + default: + alive = true + } + + return +} + +// Set the port based on some criteria +func (pe *AbstractProxy) SetPort(port int) (p int, err error) { + p = port + // Check if there is a port and is acceptable + if !(port < 65536 && port > 0) { + err = fmt.Errorf("invalid port %d", port) + return + } + + // Check if the port is taken + ln, err := net.Listen(pe.protocol, fmt.Sprintf(":%d", port)) + if err != nil { + return + } + defer ln.Close() + + pe.port = port + return +} + +// Returns the proxy port +func (pe *AbstractProxy) GetPort() int { + return pe.port +} + +func (pe *AbstractProxy) SetService(service services.Service) { + pe.service = service +} + +// Create a new instance of the proxy +func NewProxyEndpoint(port int, protocol string) (pe Proxy, err error) { + // Get the proxy for UDP or TCP + switch protocol { + case TCP: + pe, err = NewTCPProxy(port) + case UDP: + pe, err = NewUDPProxy(port) + } + + return +} + +// Interface for the proxy manager +type ProxyManager interface { + // Create a new proxy and add it to the manager + CreateProxy(port int) (*TCPProxy, error) + // Delete a proxy from the list + DeleteProxy(port int) error + // Get a proxy by the port it uses + GetProxy(port int) (*TCPProxy, error) + // Set the service for a proxy + SetService(port int, service services.Service) (pe *TCPProxy, err error) +} + +// Simple implementation of the proxy manager +// This manager has access to the proxy endpoints registered. However, it does not observe newly +// +type ProxyManagerItem struct { + ProxyManager + + // List of proxy endpoints registered in the manager + proxies []Proxy + + // Instance of the middleware manager + middlewares *MiddlewareManagerItem +} + +func (pm *ProxyManagerItem) CreateProxy(protocol string, port int) (pe Proxy, err error) { + // Check if there is another proxy with the same port + if proxy, _ := pm.GetProxy(port); proxy != nil { + err = fmt.Errorf("proxy already registered") + return + } + + // Create the proxy + pe, err = NewProxyEndpoint(port, protocol) + + // Append the proxy to the list + pm.proxies = append(pm.proxies, pe) + return +} + +// Delete a proxy from the registered list +// The proxy is stopped before being removed +func (pm *ProxyManagerItem) DeleteProxy(port int) (err error) { + // Iterate the registered proxies for the proxy using the given port, and stop and remove it from the slice + for ind, proxy := range pm.proxies { + if proxy.GetPort() == port { + // Stop the proxy, just in case + proxy.Stop() + // Remove it from the slice by replacing it with the last item from the slice, and reducing the slice + // by 1 element + lastInd := len(pm.proxies) - 1 + + pm.proxies[ind] = pm.proxies[lastInd] + pm.proxies = pm.proxies[:lastInd] + return + } + } + + // If the proxy was not foun, send an error + err = fmt.Errorf("proxy not found") + return +} + +// Returns a proxy by the port number +func (pm *ProxyManagerItem) GetProxy(port int) (pe Proxy, err error) { + // Iterate the proxies registered, and if the proxy using the given port is found, return it + for _, proxy := range pm.proxies { + if proxy.GetPort() == port { + pe = proxy + return + } + } + + // If the proxy was not foun, send an error + err = fmt.Errorf("proxy not found") + return +} + +// Set the service for some proxy +func (pm *ProxyManagerItem) SetService(port int, service services.Service) (pe Proxy, err error) { + // Get the proxy from the list + pe, err = pm.GetProxy(port) + if err != nil { + return + } + + // If the proxy was found, set the service + pe.SetService(service) + + return +} + +// Constructor for the proxy manager +func NewProxyManager() *ProxyManagerItem { + return &ProxyManagerItem{ + middlewares: Middlewares, + proxies: make([]Proxy, 0), + } +} diff --git a/internal/proxy/tcpproxy.go b/internal/proxy/tcpproxy.go new file mode 100644 index 0000000..06ed189 --- /dev/null +++ b/internal/proxy/tcpproxy.go @@ -0,0 +1,153 @@ +package proxy + +import ( + "fmt" + "io" + "log" + "net" + "sync" + "time" +) + +// Implementation of a TCP proxy + +type TCPProxy struct { + *AbstractProxy + listener net.Listener +} + +// Start listening for connections +func (tcpProxy *TCPProxy) Start() { + // Get the listener or create a new one + listener := tcpProxy.GetListener() + // Create a channel to stop the proxy + tcpProxy.stop = make(chan struct{}) + + // Add a waiting task + tcpProxy.wg.Add(1) + + go func() { + defer tcpProxy.wg.Done() + + for { + // Accept the next connection + // This goes first as it is the method we have to check if the proxy is running + // There is no need to continue if it is not + client, err := listener.Accept() + if err != nil { + // If the channel was closed, the proxy should stop + if !tcpProxy.Alive() { + return + } + fmt.Println(err) + } + defer client.Close() + + // Get a connection to the server for each new connection with the client + server, servErr := net.DialTimeout(TCP, tcpProxy.service.GetAddress(), 1*time.Second) + + // If there was an error, close the connection to the server and return + if servErr != nil { + server.Close() + return + } + defer server.Close() + + // Add a waiting task + tcpProxy.wg.Add(1) + + go func() { + // Apply the middlewares to the connection + tcpProxy.middlewares.Apply(client) + + // Handle the connection between the client and the server + // NOTE: The handlers will defer the connections + tcpProxy.handle(client, server) + + // Finish the task + tcpProxy.wg.Done() + }() + } + }() +} + +// Function to stop the proxy from runing +func (tcpProxy *TCPProxy) Stop() (err error) { + // Stop the proxy if it is still alive + if tcpProxy.Alive() { + close(tcpProxy.stop) + tcpProxy.listener.Close() + // Wait for all the connections and the server to stop + tcpProxy.wg.Wait() + return + } + + err = fmt.Errorf("proxy not running") + return +} + +// Get or create a new listener +func (tcpProxy *TCPProxy) GetListener() net.Listener { + if tcpProxy.listener == nil || !tcpProxy.Alive() { + listener, err := net.Listen(tcpProxy.protocol, fmt.Sprintf(":%d", tcpProxy.GetPort())) + if err != nil { + log.Fatal(err) + } + tcpProxy.listener = listener + } + return tcpProxy.listener +} + +// TCP synchronous tunnel that forwards requests from source to destination and back +func (tcpProxy *TCPProxy) handle(from net.Conn, to net.Conn) { + // Create the waiting group for the connections so they can answer the each other + var wg sync.WaitGroup + wg.Add(2) + + handler := func(source net.Conn, dest net.Conn) { + defer wg.Done() + + // Write the content from the source to the destination + _, err := io.Copy(dest, source) + if err != nil { + log.Print(err) + } + + // Close the connection to the source + if err := source.Close(); err != nil { + log.Print(err) + } + + // Attempt to close the writter. This may not always work + // Another solution is to just call `Close()` on the writter + if d, ok := dest.(*net.TCPConn); ok { + if err := d.CloseWrite(); err != nil { + log.Print(err) + } + + } + } + + // Start the workers + // TODO: [7/3/2022] Check somewhere if the connection is still alive from the source and destination + // Otherwise there is no need to wait + go handler(from, to) + go handler(to, from) + + // Wait until the forwarding is done + wg.Wait() +} + +func NewTCPProxy(port int) (proxy *TCPProxy, err error) { + // Create a new proxy + proxy = &TCPProxy{ + AbstractProxy: &AbstractProxy{ + middlewares: Middlewares, + protocol: TCP, + }, + } + + // Set the port + _, err = proxy.SetPort(port) + return +} diff --git a/internal/proxy/udpproxy.go b/internal/proxy/udpproxy.go new file mode 100644 index 0000000..f7e46fa --- /dev/null +++ b/internal/proxy/udpproxy.go @@ -0,0 +1,129 @@ +package proxy + +import ( + "fmt" + "log" + "net" + "sync" +) + +type UDPProxy struct { + *AbstractProxy + listener *net.UDPConn +} + +func (udpProxy *UDPProxy) Start() { + // Get the listener or create a new one + client := udpProxy.GetListener() + defer client.Close() + // Create a channel to stop the proxy + udpProxy.stop = make(chan struct{}) + + // Add a waiting task + udpProxy.wg.Add(1) + + srvAddr := net.UDPAddr{ + Port: udpProxy.service.GetPort(), + } + + for { + // Get a connection to the server for each new connection with the client + server, servErr := net.DialUDP(UDP, nil, &srvAddr) + // If there was an error, close the connection to the server and return + if servErr != nil { + server.Close() + return + } + defer server.Close() + + go func() { + // TODO: Handle the middlewares! they only accept TCP connections + // Apply the middlewares to the connection + //udpProxy.middlewares.Apply(listener) + + // Handle the connection between the client and the server + // NOTE: The handlers will defer the connections + udpProxy.handle(client, server) + + // Finish the task + udpProxy.wg.Done() + }() + } +} + +// Function to stop the proxy from runing +func (udpProxy *UDPProxy) Stop() (err error) { + // Stop the proxy if it is still alive + if udpProxy.Alive() { + close(udpProxy.stop) + udpProxy.listener.Close() + // Wait for all the connections and the server to stop + udpProxy.wg.Wait() + return + } + + err = fmt.Errorf("proxy not running") + return +} + +// Get or create a new listener +func (udpProxy *UDPProxy) GetListener() *net.UDPConn { + if udpProxy.listener == nil || !udpProxy.Alive() { + // Get the address of the UDP server + addr := net.UDPAddr{ + Port: udpProxy.service.GetPort(), + } + + listener, err := net.ListenUDP(UDP, &addr) + if err != nil { + log.Fatal(err) + } + udpProxy.listener = listener + } + + return udpProxy.listener +} + +// TODO: Test this function +// UDP asynchronous tunnel +func (udpProxy *UDPProxy) handle(client *net.UDPConn, server *net.UDPConn) { + var buf [2 << 10]byte + var wg sync.WaitGroup + wg.Add(2) + + // Function to copy messages from one pipe to the other + var handle = func(from *net.UDPConn, to *net.UDPConn) { + n, addr, err := from.ReadFrom(buf[0:]) + if err != nil { + log.Print(err) + } + + _, err = to.WriteTo(buf[:n], addr) + if err != nil { + log.Print(err) + } + } + + defer client.Close() + defer server.Close() + + go handle(client, server) + go handle(server, client) + + // Wait until the forwarding is done + wg.Wait() +} + +func NewUDPProxy(port int) (proxy *UDPProxy, err error) { + // Create a new proxy + proxy = &UDPProxy{ + AbstractProxy: &AbstractProxy{ + middlewares: Middlewares, + protocol: UDP, + }, + } + + // Set the port + _, err = proxy.SetPort(port) + return +}