슬기로운 에러 생활
[ BERT ] NotImplementedError: Layers with arguments in __init__ must override get_config
vhrehfdl
2020. 5. 26. 22:24
* 에러 메세지 : NotImplementedError: Layers with arguments in `__init__` must override `get_config`
* 에러 원인 : callback 함수를 사용해 모델을 저장하려 하는데 위에 에러가 발생하였다.
callback 함수를 사용하지 않고 model.fit을 하는 경우에는 위에 에러가 발생하지 않았다.
* 해결 방법 : BERT Layer Class 안에 get_config 함수를 추가해주니 문제가 해결되었다.
class BertLayer(tf.keras.layers.Layer):
def __init__(self, n_fine_tune_layers=10, **kwargs):
self.n_fine_tune_layers = n_fine_tune_layers
self.trainable = True
self.output_size = 768
super(BertLayer, self).__init__(**kwargs)
def get_config(self):
config = super().get_config().copy()
return config
def build(self, input_shape):
self.bert = hub.Module(
"https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
trainable=self.trainable,
name="{}_module".format(self.name)
)
해당 문제가 transformer등에서 자주 일어나는 것 같은데...
이 때 아래의 URL에 나온 것처럼 self.model 등의 파라미터를 업데이트 해주면 해결이 된다.