@@ -18,14 +18,18 @@ var (
1818 aLongTimeAgo = time .Unix (1 , 0 )
1919)
2020
21- func (d * Dialer ) connect (ctx context.Context , c net.Conn , address string ) (_ net.Addr , ctxErr error ) {
22- host , port , err := splitHostPort (address )
21+ func (d * Dialer ) connect (ctx context.Context , c net.Conn , req Request ) (conn net.Conn , _ net.Addr , ctxErr error ) {
22+ var udpHeader []byte
23+
24+ host , port , err := splitHostPort (req .DstAddress )
2325 if err != nil {
24- return nil , err
26+ return c , nil , err
2527 }
2628 if deadline , ok := ctx .Deadline (); ok && ! deadline .IsZero () {
2729 c .SetDeadline (deadline )
28- defer c .SetDeadline (noDeadline )
30+ if req .Cmd != CmdUDPAssociate {
31+ defer c .SetDeadline (noDeadline )
32+ }
2933 }
3034 if ctx != context .Background () {
3135 errCh := make (chan error , 1 )
@@ -47,14 +51,15 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
4751 }()
4852 }
4953
54+ conn = c
5055 b := make ([]byte , 0 , 6 + len (host )) // the size here is just an estimate
5156 b = append (b , Version5 )
5257 if len (d .AuthMethods ) == 0 || d .Authenticate == nil {
5358 b = append (b , 1 , byte (AuthMethodNotRequired ))
5459 } else {
5560 ams := d .AuthMethods
5661 if len (ams ) > 255 {
57- return nil , errors .New ("too many authentication methods" )
62+ return c , nil , errors .New ("too many authentication methods" )
5863 }
5964 b = append (b , byte (len (ams )))
6065 for _ , am := range ams {
@@ -69,11 +74,11 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
6974 return
7075 }
7176 if b [0 ] != Version5 {
72- return nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
77+ return c , nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
7378 }
7479 am := AuthMethod (b [1 ])
7580 if am == AuthMethodNoAcceptableMethods {
76- return nil , errors .New ("no acceptable authentication methods" )
81+ return c , nil , errors .New ("no acceptable authentication methods" )
7782 }
7883 if d .Authenticate != nil {
7984 if ctxErr = d .Authenticate (ctx , c , am ); ctxErr != nil {
@@ -82,7 +87,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
8287 }
8388
8489 b = b [:0 ]
85- b = append (b , Version5 , byte (d . cmd ), 0 )
90+ b = append (b , Version5 , byte (req . Cmd ), 0 )
8691 if ip := net .ParseIP (host ); ip != nil {
8792 if ip4 := ip .To4 (); ip4 != nil {
8893 b = append (b , AddrTypeIPv4 )
@@ -91,17 +96,23 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
9196 b = append (b , AddrTypeIPv6 )
9297 b = append (b , ip6 ... )
9398 } else {
94- return nil , errors .New ("unknown address type" )
99+ return c , nil , errors .New ("unknown address type" )
95100 }
96101 } else {
97102 if len (host ) > 255 {
98- return nil , errors .New ("FQDN too long" )
103+ return c , nil , errors .New ("FQDN too long" )
99104 }
100105 b = append (b , AddrTypeFQDN )
101106 b = append (b , byte (len (host )))
102107 b = append (b , host ... )
103108 }
104109 b = append (b , byte (port >> 8 ), byte (port ))
110+
111+ if req .Cmd == CmdUDPAssociate {
112+ udpHeader = make ([]byte , len (b ))
113+ copy (udpHeader [3 :], b [3 :])
114+ }
115+
105116 if _ , ctxErr = c .Write (b ); ctxErr != nil {
106117 return
107118 }
@@ -110,17 +121,18 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
110121 return
111122 }
112123 if b [0 ] != Version5 {
113- return nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
124+ return c , nil , errors .New ("unexpected protocol version " + strconv .Itoa (int (b [0 ])))
114125 }
115126 if cmdErr := Reply (b [1 ]); cmdErr != StatusSucceeded {
116- return nil , errors .New ("unknown error " + cmdErr .String ())
127+ return c , nil , errors .New ("unknown error " + cmdErr .String ())
117128 }
118129 if b [2 ] != 0 {
119- return nil , errors .New ("non-zero reserved field" )
130+ return c , nil , errors .New ("non-zero reserved field" )
120131 }
121132 l := 2
133+ addrType := b [3 ]
122134 var a Addr
123- switch b [ 3 ] {
135+ switch addrType {
124136 case AddrTypeIPv4 :
125137 l += net .IPv4len
126138 a .IP = make (net.IP , net .IPv4len )
@@ -129,12 +141,13 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
129141 a .IP = make (net.IP , net .IPv6len )
130142 case AddrTypeFQDN :
131143 if _ , err := io .ReadFull (c , b [:1 ]); err != nil {
132- return nil , err
144+ return c , nil , err
133145 }
134146 l += int (b [0 ])
135147 default :
136- return nil , errors .New ("unknown address type " + strconv .Itoa (int (b [3 ])))
148+ return c , nil , errors .New ("unknown address type " + strconv .Itoa (int (b [3 ])))
137149 }
150+
138151 if cap (b ) < l {
139152 b = make ([]byte , l )
140153 } else {
@@ -149,20 +162,19 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
149162 a .Name = string (b [:len (b )- 2 ])
150163 }
151164 a .Port = int (b [len (b )- 2 ])<< 8 | int (b [len (b )- 1 ])
152- return & a , nil
153- }
154165
155- func splitHostPort ( address string ) ( string , int , error ) {
156- host , port , err := net .SplitHostPort ( address )
157- if err != nil {
158- return "" , 0 , err
159- }
160- portnum , err := strconv . Atoi ( port )
161- if err != nil {
162- return "" , 0 , err
163- }
164- if 1 > portnum || portnum > 0xffff {
165- return "" , 0 , errors . New ( "port number out of range " + port )
166+ if req . Cmd == CmdUDPAssociate {
167+ var uc net.Conn
168+ if uc , err = d . proxyDial ( ctx , req . UDPNetwork , a . String ()); err != nil {
169+ return c , & a , err
170+ }
171+ c . SetDeadline ( noDeadline )
172+ go func () {
173+ defer uc . Close ()
174+ io . Copy ( io . Discard , c )
175+ }()
176+ return udpConn { Conn : uc , socksConn : c , header : udpHeader }, & a , nil
166177 }
167- return host , portnum , nil
178+
179+ return c , & a , nil
168180}
0 commit comments