1+ """ SASL transports for Thrift. """
2+
3+ from thrift .transport .TTransport import CReadableTransport , TTransportBase , TTransportException , StringIO
4+ import struct
5+
6+ class TSaslClientTransport (TTransportBase , CReadableTransport ):
7+ START = 1
8+ OK = 2
9+ BAD = 3
10+ ERROR = 4
11+ COMPLETE = 5
12+
13+ def __init__ (self , sasl_client_factory , mechanism , trans ):
14+ """
15+ @param sasl_client_factory: a callable that returns a new sasl.Client object
16+ @param mechanism: the SASL mechanism (e.g. "GSSAPI", "PLAIN")
17+ @param trans: the underlying transport over which to communicate.
18+ """
19+ self ._trans = trans
20+ self .sasl_client_factory = sasl_client_factory
21+ self .sasl = None
22+ self .mechanism = mechanism
23+ self .__wbuf = StringIO ()
24+ self .__rbuf = StringIO ()
25+ self .opened = False
26+ self .encode = None
27+
28+ def isOpen (self ):
29+ return self ._trans .isOpen ()
30+
31+ def open (self ):
32+ if not self ._trans .isOpen ():
33+ self ._trans .open ()
34+
35+ if self .sasl is not None :
36+ raise TTransportException (
37+ type = TTransportException .NOT_OPEN ,
38+ message = "Already open!" )
39+ self .sasl = self .sasl_client_factory
40+
41+ ret , chosen_mech , initial_response = self .sasl .start (self .mechanism )
42+ if not ret :
43+ raise TTransportException (type = TTransportException .NOT_OPEN ,
44+ message = ("Could not start SASL: %s" % self .sasl .getError ()))
45+
46+ # Send initial response
47+ self ._send_message (self .START , chosen_mech )
48+ self ._send_message (self .OK , initial_response )
49+
50+ # SASL negotiation loop
51+ while True :
52+ status , payload = self ._recv_sasl_message ()
53+ if status not in (self .OK , self .COMPLETE ):
54+ raise TTransportException (type = TTransportException .NOT_OPEN ,
55+ message = ("Bad status: %d (%s)" % (status , payload )))
56+ if status == self .COMPLETE :
57+ break
58+ ret , response = self .sasl .step (payload )
59+ if not ret :
60+ raise TTransportException (type = TTransportException .NOT_OPEN ,
61+ message = ("Bad SASL result: %s" % (self .sasl .getError ())))
62+ self ._send_message (self .OK , response )
63+
64+ def _send_message (self , status , body ):
65+ header = struct .pack (">BI" , status , len (body ))
66+ self ._trans .write (header + body )
67+ self ._trans .flush ()
68+
69+ def _recv_sasl_message (self ):
70+ header = self ._trans .readAll (5 )
71+ status , length = struct .unpack (">BI" , header )
72+ if length > 0 :
73+ payload = self ._trans .readAll (length )
74+ else :
75+ payload = ""
76+ return status , payload
77+
78+ def write (self , data ):
79+ self .__wbuf .write (data )
80+
81+ def flush (self ):
82+ buffer = self .__wbuf .getvalue ()
83+ # The first time we flush data, we send it to sasl.encode()
84+ # If the length doesn't change, then we must be using a QOP
85+ # of auth and we should no longer call sasl.encode(), otherwise
86+ # we encode every time.
87+ if self .encode == None :
88+ success , encoded = self .sasl .encode (buffer )
89+ if not success :
90+ raise TTransportException (type = TTransportException .UNKNOWN ,
91+ message = self .sasl .getError ())
92+ if (len (encoded )== len (buffer )):
93+ self .encode = False
94+ self ._flushPlain (buffer )
95+ else :
96+ self .encode = True
97+ self ._trans .write (encoded )
98+ elif self .encode :
99+ self ._flushEncoded (buffer )
100+ else :
101+ self ._flushPlain (buffer )
102+
103+ self ._trans .flush ()
104+ self .__wbuf = StringIO ()
105+
106+ def _flushEncoded (self , buffer ):
107+ # sasl.ecnode() does the encoding and adds the length header, so nothing
108+ # to do but call it and write the result.
109+ success , encoded = self .sasl .encode (buffer )
110+ if not success :
111+ raise TTransportException (type = TTransportException .UNKNOWN ,
112+ message = self .sasl .getError ())
113+ self ._trans .write (encoded )
114+
115+ def _flushPlain (self , buffer ):
116+ # When we have QOP of auth, sasl.encode() will pass the input to the output
117+ # but won't put a length header, so we have to do that.
118+
119+ # Note stolen from TFramedTransport:
120+ # N.B.: Doing this string concatenation is WAY cheaper than making
121+ # two separate calls to the underlying socket object. Socket writes in
122+ # Python turn out to be REALLY expensive, but it seems to do a pretty
123+ # good job of managing string buffer operations without excessive copies
124+ self ._trans .write (struct .pack (">I" , len (buffer )) + buffer )
125+
126+ def read (self , sz ):
127+ ret = self .__rbuf .read (sz )
128+ if len (ret ) != 0 :
129+ return ret
130+
131+ self ._read_frame ()
132+ return self .__rbuf .read (sz )
133+
134+ def _read_frame (self ):
135+ header = self ._trans .readAll (4 )
136+ (length ,) = struct .unpack (">I" , header )
137+ if self .encode :
138+ # If the frames are encoded (i.e. you're using a QOP of auth-int or
139+ # auth-conf), then make sure to include the header in the bytes you send to
140+ # sasl.decode()
141+ encoded = header + self ._trans .readAll (length )
142+ success , decoded = self .sasl .decode (encoded )
143+ if not success :
144+ raise TTransportException (type = TTransportException .UNKNOWN ,
145+ message = self .sasl .getError ())
146+ else :
147+ # If the frames are not encoded, just pass it through
148+ decoded = self ._trans .readAll (length )
149+ self .__rbuf = StringIO (decoded )
150+
151+ def close (self ):
152+ self ._trans .close ()
153+ self .sasl = None
154+
155+ # Implement the CReadableTransport interface.
156+ # Stolen shamelessly from TFramedTransport
157+ @property
158+ def cstringio_buf (self ):
159+ return self .__rbuf
160+
161+ def cstringio_refill (self , prefix , reqlen ):
162+ # self.__rbuf will already be empty here because fastbinary doesn't
163+ # ask for a refill until the previous buffer is empty. Therefore,
164+ # we can start reading new frames immediately.
165+ while len (prefix ) < reqlen :
166+ self ._read_frame ()
167+ prefix += self .__rbuf .getvalue ()
168+ self .__rbuf = StringIO (prefix )
169+ return self .__rbuf
0 commit comments