Warnings while saving model

Developer https://www.devze.com 2022-12-07 19:56 Source: Internet

I am using the following Transformer model for machine translation, built with KerasNLP:

```python
# Imports assumed here; the original post does not show them.
# ENG_VOCAB_SIZE, SND_VOCAB_SIZE, MAX_SEQUENCE_LENGTH, EMBED_DIM,
# INTERMEDIATE_DIM and NUM_HEADS are constants defined elsewhere.
import keras_nlp
from tensorflow import keras

# Encoder: token + position embedding followed by one Transformer encoder block
encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")

x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=ENG_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(encoder_inputs)

encoder_outputs = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
encoder = keras.Model(encoder_inputs, encoder_outputs)


# Decoder: takes the target tokens plus the encoder output sequence
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")

x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=SND_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
    mask_zero=True,
)(decoder_inputs)

x = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.5)(x)
decoder_outputs = keras.layers.Dense(SND_VOCAB_SIZE, activation="softmax")(x)
decoder = keras.Model(
    [
        decoder_inputs,
        encoded_seq_inputs,
    ],
    decoder_outputs,
)
# Wire the decoder to the encoder output to form the end-to-end model
decoder_outputs = decoder([decoder_inputs, encoder_outputs])

transformer = keras.Model(
    [encoder_inputs, decoder_inputs],
    decoder_outputs,
    name="transformer",
)
```

```python
transformer.summary()
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
hist = transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
```

When I try to save the model, it gives me warnings and does not predict correctly:

```python
import pickle

# Serialize the whole trained model object with pickle
with open('model.pkl', 'wb') as file:
    pickle.dump(transformer, file)
```

WARNING:absl:Found untraced functions such as token_embedding1_layer_call_fn, token_embedding1_layer_call_and_return_conditional_losses, position_embedding1_layer_call_fn, position_embedding1_layer_call_and_return_conditional_losses, multi_head_attention_layer_call_fn while saving (showing 5 of 78). These functions will not be directly callable after loading.
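One common alternative to pickling a Keras model is to save only its weights and load them into a freshly built copy of the same architecture. The sketch below is a minimal illustration of that round trip, not a confirmed fix for the prediction problem. `reload_via_weights` and `build_model` are hypothetical names introduced here, and the function relies only on the Keras-style `save_weights` / `load_weights` methods, so it does not import Keras itself.

```python
def reload_via_weights(model, build_model, path):
    """Save `model`'s weights to `path` and load them into a fresh model.

    Assumptions (not from the original post):
    - `model` exposes a Keras-style `save_weights(path)` method.
    - `build_model()` is a hypothetical zero-argument function that rebuilds
      the same architecture and returns an object with `load_weights(path)`.
    """
    # Write only the weights, not the pickled Python object.
    model.save_weights(path)

    # Rebuild the architecture from code, then restore the trained weights.
    fresh_model = build_model()
    fresh_model.load_weights(path)
    return fresh_model
```

With the model above, this might be called as `reloaded = reload_via_weights(transformer, build_transformer, "transformer.weights.h5")`, where `build_transformer` is an assumed function wrapping the encoder/decoder construction code from the question. Comparing `transformer.predict(...)` and `reloaded.predict(...)` on the same batch would then show whether saving and reloading is what changes the predictions.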
