How does TensorFlow set the device for each Operation?
In TensorFlow, we only need the snippet below to assign a device to an Operation:
```python
with tf.device('/GPU:0'):
    ...
    result = tf.matmul(a, b)
```
How is this implemented? Let’s take a look.
Python has a mechanism called a ‘context manager’. For example, we can use one to wrap a few lines of code:
```python
from contextlib import contextmanager

@contextmanager
def tag(name):
    print("[%s]" % name)
    yield
    print("[/%s]" % name)

with tag("robin"):
    print("what")
    print("is")
    print("nature's")
```
The result of running this script is:
```
[robin]
what
is
nature's
[/robin]
```
The function `tag()` acts as a wrapper around the code inside its `with` block: the part before `yield` runs when the block is entered, and the part after `yield` runs when the block exits.
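For comparison, here is a minimal sketch of the same wrapper written as a class that implements the context-manager protocol (`__enter__`/`__exit__`) directly. The class name `Tag` is hypothetical. `@contextmanager` builds an object with this protocol from the generator: `__enter__` runs the generator up to `yield`, and `__exit__` resumes it afterwards.

```python
class Tag:
    """Hypothetical class-based equivalent of the tag() context manager."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        # Runs when the with-block is entered (like the code before `yield`).
        print("[%s]" % self.name)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs when the with-block exits, even if the block raised.
        print("[/%s]" % self.name)
        return False  # returning False lets any exception propagate


with Tag("robin"):
    print("what")
    print("is")
    print("nature's")
```

One difference worth noting: `__exit__` here also runs when the body raises, whereas the generator version of `tag()` would need a `try`/`finally` around `yield` to guarantee the closing tag.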
TensorFlow uses the same principle.
```python
@tf_export("device")
def device(device_name_or_function):
    ...
    if context.executing_eagerly():
        # TODO(agarwal): support device functions in EAGER mode.
        if callable(device_name_or_function):
            raise RuntimeError(
                "tf.device does not support functions when eager execution "
                "is enabled.")
        return context.device(device_name_or_function)
    else:
        return get_default_graph().device(device_name_or_function)
```
In graph mode, `tf.device` calls the `device()` method of the `Graph` class. Its implementation:
```python
@tf_export("Graph")
class Graph(object):
    ...
    @tf_contextlib.contextmanager
    def device(self, device_name_or_function):
        ...
        self._add_device_to_stack(device_name_or_function, offset=2)
        try:
            yield
        finally:
            self._device_function_stack.pop_obj()
```
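The key line is `self._add_device_to_stack()`. Entering the `device` context pushes the device onto the graph's device function stack, `self._device_function_stack`, and the `finally` clause pops it again when the context exits. When the developer creates an Operation inside the context, the graph reads the device from this stack and sets it on the new Operation.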
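A stripped-down sketch of this push/pop pattern follows. `ToyGraph`, `_device_stack`, and the dict-based "operation" are hypothetical simplifications: TensorFlow's real stack holds device-function objects plus metadata, not bare strings, and reading only the top entry is a simplification of how the real code uses the stack.

```python
from contextlib import contextmanager


class ToyGraph:
    """Hypothetical, simplified graph that tracks a stack of device names."""

    def __init__(self):
        self._device_stack = []
        self.ops = []

    @contextmanager
    def device(self, device_name):
        # Push on entry, in the spirit of Graph._add_device_to_stack().
        self._device_stack.append(device_name)
        try:
            yield
        finally:
            # Pop on exit, even if the with-block raised.
            self._device_stack.pop()

    def create_op(self, op_type):
        # Read the innermost device scope (top of the stack), if any.
        device = self._device_stack[-1] if self._device_stack else ""
        op = {"type": op_type, "device": device}
        self.ops.append(op)
        return op


g = ToyGraph()
with g.device("/GPU:0"):
    a = g.create_op("MatMul")
b = g.create_op("Add")
print(a["device"], repr(b["device"]))  # /GPU:0 ''
```

The `try`/`finally` is what keeps the stack consistent: an exception inside the `with` block still pops the entry, so later operations are not silently placed on a stale device.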
Let’s follow the code path that creates an Operation:
```python
@tf_export("Graph")
class Graph(object):
    ...
    def create_op(
            self,
            op_type,
            inputs,
            dtypes,  # pylint: disable=redefined-outer-name
            input_types=None,
            name=None,
            attrs=None,
            op_def=None,
            compute_shapes=True,
            compute_device=True):
        ...
        with self._mutation_lock():
            ret = Operation(
                node_def,
                self,
                inputs=inputs,
                output_types=dtypes,
                control_inputs=control_inputs,
                input_types=input_types,
                original_op=self._default_original_op,
                op_def=op_def)
            self._create_op_helper(ret, compute_device=compute_device)
        return ret
```
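`create_op()` hands the new Operation to `_create_op_helper()`, which applies the device functions when `compute_device` is true:

```python
def _create_op_helper(self, op, compute_device=True):
    ...
    if compute_device:
        self._apply_device_functions(op)
```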
`_apply_device_functions()` then walks the device stack:

```python
def _apply_device_functions(self, op):
    ...
    for device_spec in self._device_function_stack.peek_objs():
        if device_spec.function is None:
            break
        op._set_device(device_spec.function(op))
    op._device_code_locations = self._snapshot_device_function_stack_metadata()
```
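`self._device_function_stack.peek_objs()` is where the graph peeks at the device entries on the stack. The loop visits them in LIFO order, starting with the innermost `tf.device` scope, calls each entry's `function` on the new op, and stops at the first entry whose `function` is `None`.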
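Walking the whole stack, rather than reading only the top entry, is what lets nested `tf.device` scopes combine. The sketch below is a simplified model under stated assumptions: `Op`, `DeviceEntry`, `parse`/`unparse`, `merge_device_fn`, and `apply_device_functions` are hypothetical stand-ins (TensorFlow uses its own `DeviceSpec` merging), device strings are handled only as `/key:value` parts, and, following our understanding of TensorFlow's merge behaviour, fields already set on the op by an inner scope are kept while an outer scope only fills in missing ones.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

# Only these fields are handled by this toy parser (an assumption for illustration).
KEYS = ("job", "replica", "task", "device")


def parse(device: str) -> Dict[str, str]:
    """Split '/job:worker/device:GPU:0' into {'job': 'worker', 'device': 'GPU:0'}."""
    fields = {}
    for part in device.split("/"):
        if part:
            key, _, value = part.partition(":")
            fields[key] = value
    return fields


def unparse(fields: Dict[str, str]) -> str:
    """Join the known fields back into a device string in a fixed order."""
    return "".join("/%s:%s" % (k, fields[k]) for k in KEYS if k in fields)


@dataclass
class Op:
    op_type: str
    device: str = ""


@dataclass
class DeviceEntry:
    # None marks a break in the stack (hypothetical model of a device(None) scope).
    function: Optional[Callable[[Op], str]]


def merge_device_fn(spec: str) -> Callable[[Op], str]:
    """Fields already set on the op win; missing fields are filled from `spec`."""
    def fn(op: Op) -> str:
        merged = parse(spec)
        merged.update(parse(op.device))  # the op's existing fields take precedence
        return unparse(merged)
    return fn


def apply_device_functions(op: Op, stack: List[DeviceEntry]) -> None:
    # Visit the most recently pushed (innermost) entry first.
    for entry in reversed(stack):
        if entry.function is None:
            break  # ignore every outer scope
        op.device = entry.function(op)


# Models: with device("/job:worker"): with device("/device:GPU:0"): create op
stack = [DeviceEntry(merge_device_fn("/job:worker")),
         DeviceEntry(merge_device_fn("/device:GPU:0"))]
op = Op("MatMul")
apply_device_functions(op, stack)
print(op.device)  # /job:worker/device:GPU:0
```

Because the innermost function runs first and the merge keeps fields that are already set, an inner scope overrides a conflicting outer one, while an outer scope can still contribute fields the inner one left out.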