Core concepts

The IREC project was originally designed to recover the I/O communication interface of a WDM (Windows Driver Model) driver. The WDM interface includes the following information.

  1. All control codes implemented in the driver.
  2. InputBufferLength, OutputBufferLength constraints for each control code.
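
One way to picture the recovered interface is as a table keyed by control code, with a length range per buffer. The sketch below is only an illustration: the class names, fields, and the two control codes are hypothetical and are not IREC's actual output format.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class LengthRange:
    """Inclusive bounds on a buffer length; None means unbounded on that side."""
    minimum: Optional[int] = None
    maximum: Optional[int] = None

    def allows(self, length: int) -> bool:
        if self.minimum is not None and length < self.minimum:
            return False
        if self.maximum is not None and length > self.maximum:
            return False
        return True


@dataclass
class IoctlEntry:
    """One control code plus the length constraints recovered for it."""
    code: int
    input_length: LengthRange = field(default_factory=LengthRange)
    output_length: LengthRange = field(default_factory=LengthRange)


# Hypothetical recovered interface, keyed by control code (made-up values).
interface: Dict[int, IoctlEntry] = {
    0x222004: IoctlEntry(0x222004, LengthRange(minimum=8), LengthRange(minimum=4)),
    0x222008: IoctlEntry(0x222008, LengthRange(minimum=16, maximum=16)),
}
```

With such a table, `interface[0x222004].input_length.allows(8)` answers whether a request with an 8-byte input buffer would pass the recovered length check.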

The sections below describe how IREC recovers an interface. They may be hard to follow without some familiarity with symbolic execution or angr.

Find DispatchDeviceControl function

The DispatchDeviceControl function is the dispatch routine that handles IRPs carrying an I/O control code. When a user application calls the Win32 function DeviceIoControl, this function is invoked to handle the request.

if ( !v47 )
{
    v3->MajorFunction[IRP_MJ_INTERNAL_DEVICE_CONTROL] = &sub_180002470;
    v3->MajorFunction[IRP_MJ_DEVICE_CONTROL] = &sub_180002470; // DispatchDeviceControl
    v3->MajorFunction[IRP_MJ_CLOSE] = &sub_180002470;
}

The DispatchDeviceControl function is registered early in DriverEntry, so we use angr to set a memory-write breakpoint on DriverObject->MajorFunction[IRP_MJ_DEVICE_CONTROL].

# projects/wdm.py
def find_DispatchDeviceControl(self):
    ...
    # Set a breakpoint on DriverObject->MajorFunction[MJ_DEVICE_CONTROL]
    state.inspect.b('mem_write', when=angr.BP_AFTER,
                    mem_write_address=arg_driverobject+DispatchDeviceControl_OFFSET,
                    action=self.set_major_functions)
    ...

Now symbolic execution gives us the address of DispatchDeviceControl.
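
The breakpoint address is the DRIVER_OBJECT pointer plus the offset of the MajorFunction[IRP_MJ_DEVICE_CONTROL] slot. A value like DispatchDeviceControl_OFFSET could be derived as in the sketch below. The constants are assumptions taken from the public Windows DRIVER_OBJECT layout, not from the IREC source, and the helper name is hypothetical.

```python
# Assumed constants from the public Windows DRIVER_OBJECT layout
# (not taken from the IREC source).
IRP_MJ_DEVICE_CONTROL = 0x0E                  # index into MajorFunction[]
MAJOR_FUNCTION_OFFSET = {32: 0x38, 64: 0x70}  # offset of MajorFunction[] per arch


def dispatch_device_control_offset(bits: int) -> int:
    """Offset of MajorFunction[IRP_MJ_DEVICE_CONTROL] inside DRIVER_OBJECT."""
    pointer_size = bits // 8  # each slot holds one function pointer
    return MAJOR_FUNCTION_OFFSET[bits] + IRP_MJ_DEVICE_CONTROL * pointer_size
```

Under these assumptions the slot sits at offset 0xE0 on a 64-bit driver and 0x70 on a 32-bit one.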

Obtain all control codes

To send IRP requests, we need to know which I/O control codes the DispatchDeviceControl function implements.
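
For background, an I/O control code is a 32-bit value that packs four fields, following the layout of the Windows CTL_CODE macro. The sketch below packs and unpacks those fields. The function names are hypothetical and the code is not part of IREC; it only helps when reading recovered codes.

```python
from typing import Dict


def ctl_code(device_type: int, function: int, method: int, access: int) -> int:
    """Pack the four fields the same way the CTL_CODE macro does."""
    return (device_type << 16) | (access << 14) | (function << 2) | method


def decode_ioctl(code: int) -> Dict[str, int]:
    """Split a 32-bit control code back into its four fields."""
    return {
        "device_type": (code >> 16) & 0xFFFF,
        "access": (code >> 14) & 0x3,
        "function": (code >> 2) & 0xFFF,
        "method": code & 0x3,
    }
```

For example, `decode_ioctl(0x222004)` yields device type 0x22, function 0x801, and method and access both 0.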

DispatchDeviceControl dispatches each IRP request on its I/O control code, and such dispatch routines are usually written as a switch statement or an if-else chain.

# explore_technique.py
class SwitchStateFinder(angr.ExplorationTechnique):
    def __init__(self, case):
        ...

    def setup(self, simgr):
        ...

    def step(self, simgr, stash='active', **kwargs):
        simgr = simgr.step(stash=stash, **kwargs)

        if stash == 'active' and len(simgr.stashes[stash]) > 1: # [1]
            saved_states = []
            for state in simgr.stashes[stash]:
                try:
                    io_code = state.solver.eval_one(self._case) # [2]
                    if io_code in self.switch_states: # duplicated codes
                        continue

                    self.switch_states[io_code] = state
                except:
                    saved_states.append(state)

            simgr.stashes[stash] = saved_states # [3]

        return simgr

# wdm.py
def recovery_ioctl_interface(self):
    ...
    # [1]
    state_finder = explore_technique.SwitchStateFinder(io_stack_location.fields['IoControlCode'])
    simgr.use_technique(state_finder)
    simgr.run()
    ...

To handle this, we customize the behavior of the simulation manager with what angr calls an exploration technique. We first make IO_STACK_LOCATION.IoControlCode a symbolic variable. While running the symbolic execution, we check whether the current state has split into several states [1] (a state splits when it reaches a switch or if statement). If the symbolic IoControlCode has narrowed down to a single possible value in one of those states [2], we save that state and remove it from the active stash [3]. At the end we have every I/O control code together with the state that handles it.
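
The harvesting logic can be tried without angr on a toy model. In the sketch below a "state" is simply the set of control-code values still feasible on a path, and each step mimics one comparison of an if-else chain. The model, the function names, and the small value domain are all assumptions made for illustration; a real run relies on the solver rather than on enumerating values.

```python
from typing import Dict, FrozenSet, Iterable, List

# Toy state: the set of control-code values still feasible on this path.
State = FrozenSet[int]


def step(state: State, cases: List[int]) -> List[State]:
    """Mimic one comparison: split on the first case that is still feasible."""
    for case in cases:
        if case in state and len(state) > 1:
            return [frozenset({case}), state - {case}]
    return [state]  # no comparison left to split on


def harvest(domain: Iterable[int], cases: List[int]) -> Dict[int, State]:
    active: List[State] = [frozenset(domain)]
    switch_states: Dict[int, State] = {}
    while active:
        successors: List[State] = []
        for state in active:
            successors.extend(step(state, cases))
        if len(successors) <= 1:
            break  # nothing split: the toy path reached the default branch
        saved: List[State] = []  # [1] the state split
        for state in successors:
            if len(state) == 1:  # [2] narrowed to a single value
                (code,) = state
                switch_states.setdefault(code, state)
            else:
                saved.append(state)
        active = saved  # [3] keep exploring only the undecided states
    return switch_states
```

Calling `harvest(range(0x222000, 0x222010), [0x222004, 0x222008])` returns a dictionary whose keys are exactly the two handled codes.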


Inspect constraints

The constraints involve the following variables.

  1. InputBuffer
  2. OutputBuffer
  3. InputBufferLength
  4. OutputBufferLength

To get the constraints for each control code, we use jump_guards from the history plugin, which lists the constraints guarding each branch the state has taken.

# wdm.py
def recovery_ioctl_interface(self):
    ...
    for ioctl_code, case_state in switch_states.items():
        def get_constraint_states(st):
            ...
            simgr = self.project.factory.simgr(st)

            for i in range(10):
                simgr.step()

                for state in simgr.active:
                    for constraint in state.history.jump_guards:
                        if 'BufferLength' in str(constraint) and \
                            str(constraint) not in preconstraints:
                            yield state

        # Inspect what constraints are used.
        constraint_states = get_constraint_states(case_state)

Starting from the state obtained in the previous stage, we step forward until a buffer-length constraint first appears.
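
The filter in the snippet above works on the string form of each guard. A minimal stand-alone sketch of that check follows. The guard strings are made up for illustration and are not real solver output, and the helper name is hypothetical.

```python
from typing import Iterable, List, Set


def new_length_guards(jump_guards: Iterable[str], preconstraints: Set[str]) -> List[str]:
    """Keep guards that mention a buffer length and were not already known."""
    return [
        guard
        for guard in jump_guards
        if "BufferLength" in guard and guard not in preconstraints
    ]


# Made-up guard strings, standing in for str(constraint).
guards = [
    "IoControlCode == 0x222004",
    "InputBufferLength >= 0x8",
    "OutputBufferLength >= 0x4",
]
known = {"OutputBufferLength >= 0x4"}
fresh = new_length_guards(guards, known)
```

Here `fresh` holds only the input-length guard, since the control-code guard is not a length check and the output-length guard was already known.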

This gives us two states, which we call sat and unsat. Only one of them carries the valid conditions that we want to recover.

# wdm.py
def recovery_ioctl_interface(self):
    ...
    constraint_states = get_constraint_states(case_state)

    try:
        sat_state = next(constraint_states)
        unsat_state = next(constraint_states)
    except:
        ...

    simgr_sat = self.project.factory.simgr(sat_state)
    simgr_unsat = self.project.factory.simgr(unsat_state)

    def determine_unsat():
        for _ in range(30):
            simgr_sat.step()
            simgr_unsat.step()

            if len(simgr_sat.active) == 0:
                yield False
            elif len(simgr_unsat.active) == 0:
                yield True

    if not next(determine_unsat()):
        sat_state, unsat_state = unsat_state, sat_state

We then need to determine which of the two states satisfies the condition. Many algorithms could be used for this. We assume that the state on the wrong side of the condition (unsat) terminates before the state on the correct side (sat), because all the unsat path has to do is return an NT error status. This heuristic is quite accurate and fast.
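
The heuristic can be modelled without angr by giving each candidate path a number of steps after which it has no active states left. This is a toy model with hypothetical names, not IREC code. It also makes two edge cases visible: a tie goes to the candidate checked first, and neither path may finish within the step budget. In the snippet above, that last case would surface as a StopIteration from next().

```python
from typing import Optional


def first_to_finish(steps_first: int, steps_second: int,
                    max_steps: int = 30) -> Optional[str]:
    """Return which candidate runs out of active states first, or None.

    The candidate that finishes first is the one treated as unsat.
    """
    for taken in range(1, max_steps + 1):
        if taken >= steps_first:   # checked first, as in the original order
            return "first"
        if taken >= steps_second:
            return "second"
    return None                    # neither path ended within the budget
```

For instance, `first_to_finish(3, 12)` picks the first candidate as the unsat path, while `first_to_finish(40, 50)` returns None because neither path ends within 30 steps.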

# Get valid constraints.
def get_valid_constraints(sat_state, unsat_state):
    simgr = self.project.factory.simgr(sat_state)

    for _ in range(10):
        simgr.step()

    for states in list(simgr.stashes.values()):
        for state in states:
            if unsat_state.addr not in state.history.bbl_addrs:
                return states

sat_state = get_valid_constraints(sat_state, unsat_state)

Once sat and unsat are determined, we can collect all the correct constraint conditions using bbl_addrs. The bbl_addrs attribute of the history plugin lists the basic block addresses the state has executed, so we only need to find a state that never passes through unsat's address.
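
The address check itself is a plain membership test over each path's block history. A minimal sketch follows, with each path represented as a list of basic block addresses. The addresses and the helper name are invented for illustration.

```python
from typing import Iterable, List, Sequence


def paths_avoiding(paths: Iterable[Sequence[int]], rejected_addr: int) -> List[Sequence[int]]:
    """Keep only the paths that never executed the rejected basic block."""
    return [path for path in paths if rejected_addr not in path]


# Invented block histories; 0x1400 stands in for unsat's address.
histories = [
    [0x1000, 0x1200, 0x1400],          # falls into the error block
    [0x1000, 0x1200, 0x1300, 0x1500],  # stays on the accepted path
]
valid = paths_avoiding(histories, 0x1400)
```

Only the second history survives, which corresponds to a state whose accumulated constraints all belong to the accepted path.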