@@ -6,6 +6,7 @@ package middleware
66import (
77 "bytes"
88 "context"
9+ "crypto/tls"
910 "errors"
1011 "fmt"
1112 "io"
@@ -20,6 +21,7 @@ import (
2021
2122 "github.com/labstack/echo/v4"
2223 "github.com/stretchr/testify/assert"
24+ "golang.org/x/net/websocket"
2325)
2426
2527// Assert expected with url.EscapedPath method to obtain the path.
@@ -810,3 +812,231 @@ func TestModifyResponseUseContext(t *testing.T) {
810812 assert .Equal (t , "OK" , rec .Body .String ())
811813 assert .Equal (t , "CUSTOM_BALANCER" , rec .Header ().Get ("FROM_BALANCER" ))
812814}
815+
816+ func createSimpleWebSocketServer (serveTLS bool ) * httptest.Server {
817+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
818+ wsHandler := func (conn * websocket.Conn ) {
819+ defer conn .Close ()
820+ for {
821+ var msg string
822+ err := websocket .Message .Receive (conn , & msg )
823+ if err != nil {
824+ return
825+ }
826+ // message back to the client
827+ websocket .Message .Send (conn , msg )
828+ }
829+ }
830+ websocket.Server {Handler : wsHandler }.ServeHTTP (w , r )
831+ })
832+ if serveTLS {
833+ return httptest .NewTLSServer (handler )
834+ }
835+ return httptest .NewServer (handler )
836+ }
837+
838+ func createSimpleProxyServer (t * testing.T , srv * httptest.Server , serveTLS bool , toTLS bool ) * httptest.Server {
839+ e := echo .New ()
840+
841+ if toTLS {
842+ // proxy to tls target
843+ tgtURL , _ := url .Parse (srv .URL )
844+ tgtURL .Scheme = "wss"
845+ balancer := NewRandomBalancer ([]* ProxyTarget {{URL : tgtURL }})
846+
847+ defaultTransport , ok := http .DefaultTransport .(* http.Transport )
848+ if ! ok {
849+ t .Fatal ("Default transport is not of type *http.Transport" )
850+ }
851+ transport := defaultTransport .Clone ()
852+ transport .TLSClientConfig = & tls.Config {
853+ InsecureSkipVerify : true ,
854+ }
855+ e .Use (ProxyWithConfig (ProxyConfig {Balancer : balancer , Transport : transport }))
856+ } else {
857+ // proxy to non-TLS target
858+ tgtURL , _ := url .Parse (srv .URL )
859+ balancer := NewRandomBalancer ([]* ProxyTarget {{URL : tgtURL }})
860+ e .Use (ProxyWithConfig (ProxyConfig {Balancer : balancer }))
861+ }
862+
863+ if serveTLS {
864+ // serve proxy server with TLS
865+ ts := httptest .NewTLSServer (e )
866+ return ts
867+ }
868+ // serve proxy server without TLS
869+ ts := httptest .NewServer (e )
870+ return ts
871+ }
872+
873+ // TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
874+ func TestProxyWithConfigWebSocketNonTLS2NonTLS (t * testing.T ) {
875+ /*
876+ Arrange
877+ */
878+ // Create a WebSocket test server (non-TLS)
879+ srv := createSimpleWebSocketServer (false )
880+ defer srv .Close ()
881+
882+ // create proxy server (non-TLS to non-TLS)
883+ ts := createSimpleProxyServer (t , srv , false , false )
884+ defer ts .Close ()
885+
886+ tsURL , _ := url .Parse (ts .URL )
887+ tsURL .Scheme = "ws"
888+ tsURL .Path = "/"
889+
890+ /*
891+ Act
892+ */
893+
894+ // Connect to the proxy WebSocket
895+ wsConn , err := websocket .Dial (tsURL .String (), "" , "http://localhost/" )
896+ assert .NoError (t , err )
897+ defer wsConn .Close ()
898+
899+ // Send message
900+ sendMsg := "Hello, Non TLS WebSocket!"
901+ err = websocket .Message .Send (wsConn , sendMsg )
902+ assert .NoError (t , err )
903+
904+ /*
905+ Assert
906+ */
907+ // Read response
908+ var recvMsg string
909+ err = websocket .Message .Receive (wsConn , & recvMsg )
910+ assert .NoError (t , err )
911+ assert .Equal (t , sendMsg , recvMsg )
912+ }
913+
914+ // TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
915+ func TestProxyWithConfigWebSocketTLS2TLS (t * testing.T ) {
916+ /*
917+ Arrange
918+ */
919+ // Create a WebSocket test server (TLS)
920+ srv := createSimpleWebSocketServer (true )
921+ defer srv .Close ()
922+
923+ // create proxy server (TLS to TLS)
924+ ts := createSimpleProxyServer (t , srv , true , true )
925+ defer ts .Close ()
926+
927+ tsURL , _ := url .Parse (ts .URL )
928+ tsURL .Scheme = "wss"
929+ tsURL .Path = "/"
930+
931+ /*
932+ Act
933+ */
934+ origin , err := url .Parse (ts .URL )
935+ assert .NoError (t , err )
936+ config := & websocket.Config {
937+ Location : tsURL ,
938+ Origin : origin ,
939+ TlsConfig : & tls.Config {InsecureSkipVerify : true }, // skip verify for testing
940+ Version : websocket .ProtocolVersionHybi13 ,
941+ }
942+ wsConn , err := websocket .DialConfig (config )
943+ assert .NoError (t , err )
944+ defer wsConn .Close ()
945+
946+ // Send message
947+ sendMsg := "Hello, TLS to TLS WebSocket!"
948+ err = websocket .Message .Send (wsConn , sendMsg )
949+ assert .NoError (t , err )
950+
951+ // Read response
952+ var recvMsg string
953+ err = websocket .Message .Receive (wsConn , & recvMsg )
954+ assert .NoError (t , err )
955+ assert .Equal (t , sendMsg , recvMsg )
956+ }
957+
958+ // TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
959+ func TestProxyWithConfigWebSocketNonTLS2TLS (t * testing.T ) {
960+ /*
961+ Arrange
962+ */
963+
964+ // Create a WebSocket test server (TLS)
965+ srv := createSimpleWebSocketServer (true )
966+ defer srv .Close ()
967+
968+ // create proxy server (Non-TLS to TLS)
969+ ts := createSimpleProxyServer (t , srv , false , true )
970+ defer ts .Close ()
971+
972+ tsURL , _ := url .Parse (ts .URL )
973+ tsURL .Scheme = "ws"
974+ tsURL .Path = "/"
975+
976+ /*
977+ Act
978+ */
979+ // Connect to the proxy WebSocket
980+ wsConn , err := websocket .Dial (tsURL .String (), "" , "http://localhost/" )
981+ assert .NoError (t , err )
982+ defer wsConn .Close ()
983+
984+ // Send message
985+ sendMsg := "Hello, Non TLS to TLS WebSocket!"
986+ err = websocket .Message .Send (wsConn , sendMsg )
987+ assert .NoError (t , err )
988+
989+ /*
990+ Assert
991+ */
992+ // Read response
993+ var recvMsg string
994+ err = websocket .Message .Receive (wsConn , & recvMsg )
995+ assert .NoError (t , err )
996+ assert .Equal (t , sendMsg , recvMsg )
997+ }
998+
999+ // TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
1000+ func TestProxyWithConfigWebSocketTLS2NonTLS (t * testing.T ) {
1001+ /*
1002+ Arrange
1003+ */
1004+
1005+ // Create a WebSocket test server (non-TLS)
1006+ srv := createSimpleWebSocketServer (false )
1007+ defer srv .Close ()
1008+
1009+ // create proxy server (TLS to non-TLS)
1010+ ts := createSimpleProxyServer (t , srv , true , false )
1011+ defer ts .Close ()
1012+
1013+ tsURL , _ := url .Parse (ts .URL )
1014+ tsURL .Scheme = "wss"
1015+ tsURL .Path = "/"
1016+
1017+ /*
1018+ Act
1019+ */
1020+ origin , err := url .Parse (ts .URL )
1021+ assert .NoError (t , err )
1022+ config := & websocket.Config {
1023+ Location : tsURL ,
1024+ Origin : origin ,
1025+ TlsConfig : & tls.Config {InsecureSkipVerify : true }, // skip verify for testing
1026+ Version : websocket .ProtocolVersionHybi13 ,
1027+ }
1028+ wsConn , err := websocket .DialConfig (config )
1029+ assert .NoError (t , err )
1030+ defer wsConn .Close ()
1031+
1032+ // Send message
1033+ sendMsg := "Hello, TLS to NoneTLS WebSocket!"
1034+ err = websocket .Message .Send (wsConn , sendMsg )
1035+ assert .NoError (t , err )
1036+
1037+ // Read response
1038+ var recvMsg string
1039+ err = websocket .Message .Receive (wsConn , & recvMsg )
1040+ assert .NoError (t , err )
1041+ assert .Equal (t , sendMsg , recvMsg )
1042+ }
0 commit comments