chore: fix CustomCallback and isinstance of History
This commit is contained in:
parent
1e6d04c816
commit
8274609ad1
@ -1,6 +1,15 @@
|
|||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.13"
|
||||||
|
# dependencies = [
|
||||||
|
# "keras==3.8.0",
|
||||||
|
# "marimo",
|
||||||
|
# "numpy==2.2.2",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
import marimo
|
import marimo
|
||||||
|
|
||||||
__generated_with = "0.10.16"
|
__generated_with = "0.10.17"
|
||||||
app = marimo.App(width="medium")
|
app = marimo.App(width="medium")
|
||||||
|
|
||||||
|
|
||||||
@ -458,15 +467,21 @@ def f1_score_metric(PrecisionMetric, RecallMetric, keras, mo, tf):
|
|||||||
|
|
||||||
|
|
||||||
@app.cell(hide_code=True)
|
@app.cell(hide_code=True)
|
||||||
def custom_validation_metrics(X_val, mo, tf, y_val):
|
def custom_validation_metrics(mo, tf):
|
||||||
#Custom callback to compute metrics on validation data
|
#Custom callback to compute metrics on validation data
|
||||||
class CustomValidationMetrics(tf.keras.callbacks.Callback):
|
class CustomValidationMetrics(tf.keras.callbacks.Callback):
|
||||||
|
def __init__(self, X_val, y_val):
|
||||||
|
super().__init__() # Initialize the parent class
|
||||||
|
self.X_val = X_val
|
||||||
|
self.y_val = y_val
|
||||||
|
|
||||||
def on_epoch_end(self, epoch, logs=None):
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
val_predictions = self.model.predict(X_val, verbose=0)
|
val_predictions = self.model.predict(self.X_val, verbose=0)
|
||||||
val_predictions = (val_predictions > 0.5).astype(int) # Binarize predictions
|
val_predictions = (val_predictions > 0.5).astype(int) # Binarize predictions
|
||||||
|
|
||||||
precision = tf.keras.metrics.Precision()(y_val, val_predictions)
|
# Compute precision, recall, and f1-score
|
||||||
recall = tf.keras.metrics.Recall()(y_val, val_predictions)
|
precision = tf.keras.metrics.Precision()(self.y_val, val_predictions)
|
||||||
|
recall = tf.keras.metrics.Recall()(self.y_val, val_predictions)
|
||||||
f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
|
f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
|
||||||
|
|
||||||
print(f"\nEpoch {epoch + 1} Validation Metrics - Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
|
print(f"\nEpoch {epoch + 1} Validation Metrics - Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}")
|
||||||
@ -543,7 +558,7 @@ def show_train_model_shape(mo, perfspec, prepare_train, verbose):
|
|||||||
|
|
||||||
|
|
||||||
@app.cell(hide_code=True)
|
@app.cell(hide_code=True)
|
||||||
def make_model(mo, np, perfspec):
|
def make_model(CustomValidationMetrics, mo, np, perfspec):
|
||||||
# Define the LSTM model
|
# Define the LSTM model
|
||||||
def make_model(X=[],y=[],label_encoder=[], encoded_actions=[]):
|
def make_model(X=[],y=[],label_encoder=[], encoded_actions=[]):
|
||||||
if len(X) == 0 or len(y) == 0:
|
if len(X) == 0 or len(y) == 0:
|
||||||
@ -569,13 +584,13 @@ def make_model(mo, np, perfspec):
|
|||||||
perfspec['vars']['model'] = Sequential(
|
perfspec['vars']['model'] = Sequential(
|
||||||
[
|
[
|
||||||
#Embedding(input_dim=vocab_size, output_dim=embedding_dim),
|
#Embedding(input_dim=vocab_size, output_dim=embedding_dim),
|
||||||
|
Input(shape=(perfspec['settings']['sequence_length'], 1)),
|
||||||
LSTM(
|
LSTM(
|
||||||
perfspec['settings']['lstm_units_1'],
|
perfspec['settings']['lstm_units_1'],
|
||||||
return_sequences=True,
|
return_sequences=True,
|
||||||
recurrent_dropout=perfspec['settings']['dropout_rate'],
|
recurrent_dropout=perfspec['settings']['dropout_rate'],
|
||||||
#input_shape = (2,vocab_size),
|
input_shape=(perfspec['settings']['sequence_length'], 1),
|
||||||
),
|
),
|
||||||
Input(shape=(perfspec['settings']['sequence_length'], 1)),
|
|
||||||
LSTM(
|
LSTM(
|
||||||
perfspec['settings']['lstm_units_2'],
|
perfspec['settings']['lstm_units_2'],
|
||||||
return_sequences=False,
|
return_sequences=False,
|
||||||
@ -622,11 +637,12 @@ def make_model(mo, np, perfspec):
|
|||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
early_stopping = EarlyStopping(
|
early_stopping = EarlyStopping(
|
||||||
monitor="val_loss", patience=5, restore_best_weights=True
|
monitor="val_loss", patience=10, restore_best_weights=True
|
||||||
)
|
)
|
||||||
lr_reduction = ReduceLROnPlateau(
|
lr_reduction = ReduceLROnPlateau(
|
||||||
monitor="val_loss", patience=3, factor=0.5, min_lr=0.0001
|
monitor="val_loss", patience=8, factor=0.8, min_lr=0.0001
|
||||||
)
|
)
|
||||||
|
custom_metrics_callback = CustomValidationMetrics(X, y)
|
||||||
if perfspec['settings']['checkpoint_mode'] == "weights":
|
if perfspec['settings']['checkpoint_mode'] == "weights":
|
||||||
# Save only the weights of the model instead of the full model.
|
# Save only the weights of the model instead of the full model.
|
||||||
checkpoint = ModelCheckpoint(
|
checkpoint = ModelCheckpoint(
|
||||||
@ -645,8 +661,9 @@ def make_model(mo, np, perfspec):
|
|||||||
verbose=1 # Print messages when saving
|
verbose=1 # Print messages when saving
|
||||||
)
|
)
|
||||||
|
|
||||||
callbacks=[early_stopping,lr_reduction] #,CustomValidationMetrics]
|
callbacks=[early_stopping,lr_reduction]
|
||||||
callbacks=[] #,CustomValidationMetrics]
|
callbacks=[early_stopping,lr_reduction]
|
||||||
|
callbacks.append(custom_metrics_callback)
|
||||||
if checkpoint != None:
|
if checkpoint != None:
|
||||||
callbacks.append(checkpoint)
|
callbacks.append(checkpoint)
|
||||||
|
|
||||||
@ -676,7 +693,7 @@ def make_model(mo, np, perfspec):
|
|||||||
|
|
||||||
This is where **model** is creates and **fit**
|
This is where **model** is creates and **fit**
|
||||||
|
|
||||||
Saved in `perfspec['vars'] as `model` and `history`
|
Saved in `perfspec['vars']` as `model` and `history`
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
return (make_model,)
|
return (make_model,)
|
||||||
@ -791,10 +808,11 @@ def perfspec_save_model(Path, mo, perfspec):
|
|||||||
def perfspec_plot_history(Path, mo):
|
def perfspec_plot_history(Path, mo):
|
||||||
def plot_history(perfspec):
|
def plot_history(perfspec):
|
||||||
import json
|
import json
|
||||||
|
from keras.src.callbacks import History
|
||||||
if 'vars' not in perfspec:
|
if 'vars' not in perfspec:
|
||||||
return None
|
return None
|
||||||
if perfspec['vars']['history'] != None:
|
if perfspec['vars']['history'] != None:
|
||||||
if 'history' in perfspec['vars']['history']:
|
if isinstance(perfspec['vars']['history'], History):
|
||||||
_model_history = perfspec['vars']['history'].history
|
_model_history = perfspec['vars']['history'].history
|
||||||
else:
|
else:
|
||||||
_model_history = perfspec['vars']['history']
|
_model_history = perfspec['vars']['history']
|
||||||
@ -993,8 +1011,9 @@ def perfspec_evaluate_model(Path, mo, np, prepare_train):
|
|||||||
|
|
||||||
def history_info(perfspec):
|
def history_info(perfspec):
|
||||||
import json
|
import json
|
||||||
|
from keras.src.callbacks import History
|
||||||
if perfspec['vars']['history'] != None:
|
if perfspec['vars']['history'] != None:
|
||||||
if 'history' in perfspec['vars']['history']:
|
if isinstance(perfspec['vars']['history'], History):
|
||||||
model_history = perfspec['vars']['history'].history
|
model_history = perfspec['vars']['history'].history
|
||||||
else:
|
else:
|
||||||
model_history = perfspec['vars']['history']
|
model_history = perfspec['vars']['history']
|
||||||
@ -1007,7 +1026,7 @@ def perfspec_evaluate_model(Path, mo, np, prepare_train):
|
|||||||
model_history = json.load(history_file)
|
model_history = json.load(history_file)
|
||||||
if model_history != None:
|
if model_history != None:
|
||||||
from prettytable import PrettyTable
|
from prettytable import PrettyTable
|
||||||
rain_loss = model_history['loss']
|
train_loss = model_history['loss']
|
||||||
val_loss = model_history['val_loss']
|
val_loss = model_history['val_loss']
|
||||||
train_acc = model_history['accuracy']
|
train_acc = model_history['accuracy']
|
||||||
val_acc = model_history['val_accuracy']
|
val_acc = model_history['val_accuracy']
|
||||||
|
Loading…
Reference in New Issue
Block a user