Python: registering functions with a decorator

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Registry of functions, keyed by the name passed to ``register``.
mapping = dict()


def register(name):
	"""Return a decorator that stores a function in ``mapping`` under ``name``.

	The decorated function is handed back unchanged, so it stays callable
	under its own name. Registering the same ``name`` again replaces the
	earlier entry.
	"""
	def decorator(func):
		mapping.update({name: func})
		return func

	return decorator

def get_neural_network_architecture(architecture_name):
	"""Return the function registered under ``architecture_name``.

	Args:
		architecture_name: Key to look up in the module-level ``mapping``.

	Returns:
		The object stored in ``mapping`` for that key.

	Raises:
		ValueError: If ``architecture_name`` is not a key of ``mapping``.
	"""
	try:
		# Single lookup (EAFP) instead of a membership test plus a second lookup.
		return mapping[architecture_name]
	except KeyError:
		# ``from None`` keeps the internal KeyError out of the traceback, so
		# callers only see the ValueError they are expected to handle.
		raise ValueError('Unknown architecture name: {}'.format(architecture_name)) from None

@register("mlp")
def mlp(hidden_layers):
	"""Print a line showing that ``mlp`` ran, followed by ``hidden_layers``."""
	prefix = 'running mlp(hidden_layers) '
	print(prefix + str(hidden_layers))

@register("cnn")
def cnn(hidden_layers, stride):
	"""Print a line showing that ``cnn`` ran, followed by ``stride``.

	``hidden_layers`` is accepted but not used in the printed message.
	"""
	prefix = 'running cnn(hidden_layers, stride) '
	print(prefix + str(stride))


if __name__ == '__main__':
	# From another module, import get_neural_network_architecture before
	# performing this lookup.
	architecture = get_neural_network_architecture("cnn")
	print(architecture)
	# Call the looked-up function with hidden_layers=2 and stride=1.
	architecture(2, 1)
1
2
3
4
Out:

<function cnn at 0x00000219BA78A950>
running cnn(hidden_layers, stride) 1