How Does TensorFlow Set the Device for Each Operation?

In TensorFlow, the snippet below is all we need to assign a device to an Operation:

Python
with tf.device('/GPU:0'):
  ...
  result = tf.matmul(a, b)

How is this implemented? Let’s take a look.

Python has a mechanism called a ‘context manager’. For example, we can use it to wrap a few lines of code:

Python
from contextlib import contextmanager

@contextmanager
def tag(name):
    print("[%s]" % name)
    yield
    print("[/%s]" % name)

with tag("robin"):
    print("what")
    print("is")
    print("nature's")

The result of running this script is:

[robin]
what
is
nature's
[/robin]

Function ‘tag()’ works like a decorator: it runs something before and after the code that sits inside its ‘context’.
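
The generator form is shorthand for an object with `__enter__` and `__exit__` methods. As a rough sketch (the class name `Tag` and the `log` list are invented here for illustration), the same wrapper can be written by hand. Unlike a generator version without try/finally, `__exit__` here runs the closing step even if the body raises:

```python
class Tag(object):
    """Hand-written context manager, for illustration only."""

    def __init__(self, name, log):
        self.name = name
        self.log = log  # lines are collected here instead of being printed

    def __enter__(self):
        # Runs before the body of the `with` block.
        self.log.append("[%s]" % self.name)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs after the body, whether it finished normally or raised.
        self.log.append("[/%s]" % self.name)
        return False  # do not swallow exceptions


log = []
with Tag("robin", log):
    log.append("what")

try:
    with Tag("broken", log):
        raise ValueError("boom")
except ValueError:
    pass

print(log)
# ['[robin]', 'what', '[/robin]', '[broken]', '[/broken]']
```

The closing tag for "broken" is still recorded, which is the same guarantee a try/finally around the yield gives in the generator form.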

TensorFlow uses the same principle. Here is the implementation of ‘tf.device()’:

Python
@tf_export("device")
def device(device_name_or_function):
  ...
  if context.executing_eagerly():
    # TODO(agarwal): support device functions in EAGER mode.
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)
  else:
    return get_default_graph().device(device_name_or_function)

In graph mode, this calls the ‘device()’ method of class Graph, which is implemented as follows:

Python
@tf_export("Graph")
class Graph(object):
  ...
  @tf_contextlib.contextmanager
  def device(self, device_name_or_function):
    ...
    self._add_device_to_stack(device_name_or_function, offset=2)
    try:
      yield
    finally:
      self._device_function_stack.pop_obj()

The key line is ‘self._add_device_to_stack()’. The ‘device’ context pushes the device name onto a stack kept by the graph (‘self._device_function_stack’) and pops it again when the block ends. When the developer creates an Operation, the device name is fetched from that stack and set on the Operation.
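
To make the push and pop behaviour concrete, here is a minimal sketch in plain Python. `ToyGraph`, its `_device_stack` list and the dictionary used as an "operation" are all made up for this illustration; the real class does more (for example, it also accepts device functions), so treat this only as a model of the idea:

```python
from contextlib import contextmanager


class ToyGraph(object):
    """Toy stand-in for a graph; not TensorFlow's real class."""

    def __init__(self):
        self._device_stack = []  # innermost device name is the last item

    @contextmanager
    def device(self, device_name):
        self._device_stack.append(device_name)  # push on entering the block
        try:
            yield
        finally:
            self._device_stack.pop()  # pop on leaving, even after an error

    def create_op(self, op_type):
        # Read the innermost device name, or "" when no block is active.
        device = self._device_stack[-1] if self._device_stack else ""
        return {"type": op_type, "device": device}


graph = ToyGraph()
with graph.device("/GPU:0"):
    matmul = graph.create_op("MatMul")
    with graph.device("/CPU:0"):
        add = graph.create_op("Add")
    relu = graph.create_op("Relu")
outside = graph.create_op("Const")

print(matmul["device"], add["device"], relu["device"], repr(outside["device"]))
# /GPU:0 /CPU:0 /GPU:0 ''
```

Because the name is popped in a `finally` clause, an operation created after the inner block falls back to the outer device, and one created outside every block gets no device at all.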
Let’s check the code path for creating an Operation:

Python
@tf_export("Graph")
class Graph(object):
  ...
  def create_op(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_shapes=True,
      compute_device=True):
    ...
    with self._mutation_lock():
      ret = Operation(
          node_def,
          self,
          inputs=inputs,
          output_types=dtypes,
          control_inputs=control_inputs,
          input_types=input_types,
          original_op=self._default_original_op,
          op_def=op_def)
      self._create_op_helper(ret, compute_device=compute_device)
    return ret

Python
def _create_op_helper(self, op, compute_device=True):
  ...
  if compute_device:
    self._apply_device_functions(op)

Python
def _apply_device_functions(self, op):
  ...
  for device_spec in self._device_function_stack.peek_objs():
    if device_spec.function is None:
      break
    op._set_device(device_spec.function(op))
  op._device_code_locations = self._snapshot_device_function_stack_metadata()

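‘self._device_function_stack.peek_objs()’ is where the device entries are read from the stack. The function of each entry is then called with the Operation, and the result is assigned through ‘op._set_device()’.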
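
The loop above can be imitated in a few lines of plain Python. This is only a sketch under stated assumptions: the stack is an ordinary list whose last item is the innermost entry, an operation is a dictionary, and `fill_if_unset` is a hypothetical helper that turns a device name into a function (the real code merges device specifications in more detail). `by_type` is likewise an invented example of a device function:

```python
def apply_device_functions(op, device_function_stack):
    """Toy version of the loop: walk the stack from innermost to outermost."""
    for function in reversed(device_function_stack):
        if function is None:
            # A None entry stops the walk, so outer entries are ignored.
            break
        op["device"] = function(op)
    return op


def fill_if_unset(device_name):
    """Hypothetical helper: wrap a device name in a function of the op."""
    def device_function(op):
        # Keep a device chosen by an inner entry; otherwise use this name.
        return op["device"] or device_name
    return device_function


def by_type(op):
    """Example device function: decide from the operation itself."""
    return "/CPU:0" if op["type"] == "Variable" else "/GPU:0"


def new_op(op_type):
    return {"type": op_type, "device": ""}


outer_only = [fill_if_unset("/GPU:1")]
nested = [fill_if_unset("/GPU:1"), by_type]  # by_type is the innermost entry
reset = [fill_if_unset("/GPU:1"), None]      # None is the innermost entry

print(apply_device_functions(new_op("MatMul"), outer_only)["device"])  # /GPU:1
print(apply_device_functions(new_op("Variable"), nested)["device"])    # /CPU:0
print(apply_device_functions(new_op("MatMul"), nested)["device"])      # /GPU:0
print(repr(apply_device_functions(new_op("MatMul"), reset)["device"])) # ''
```

The innermost entry gets the first chance to choose a device, and a None entry ends the walk before any outer entry is consulted.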