Achieve better timestamps
rwitten committed Apr 27, 2024
1 parent 18ba1a7 commit 2522032
Showing 1 changed file with 20 additions and 7 deletions.
MaxText/train.py: 20 additions & 7 deletions
@@ -101,9 +101,11 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
 
 _buffered_step = None
 _buffered_metrics = None
+_last_buffer_time = None
+_buffered_lr = None
 
 
-def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
+def write_metrics(writer, per_device_tflops, lr, local_metrics_file, running_gcs_metrics, metrics, step, config):
   """Entry point for all metrics writing in Train's Main.
   TODO: would be better as a Class in the future (that initialized all state!)
@@ -112,11 +114,23 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
   The logic is that this ensures that Jax is able to queues train_steps and we
   don't block when turning "lazy" Jax arrays into real Python numbers.
   """
-  global _buffered_step, _buffered_metrics
+  global _buffered_step, _buffered_metrics, _last_buffer_time, _buffered_lr
+
 
+  if _buffered_metrics is not None:
+    jax.block_until_ready(_buffered_metrics)
+
+  next_buffer_time = datetime.datetime.now()
+
   if _buffered_metrics is not None:
     if _buffered_step is None:
       raise ValueError(f"When writing metrics, {_buffered_step=} was none")
+    if _last_buffer_time is None:
+      raise ValueError(f"When writing metrics, {_last_buffer_time=} was none")
+    if _buffered_lr is None:
+      raise ValueError(f"When writing metrics, {_buffered_lr=} was none")
+
+    record_scalar_metrics(_buffered_metrics, (next_buffer_time - _last_buffer_time), per_device_tflops, _buffered_lr)
     write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
 
     if config.metrics_file:
@@ -127,6 +141,8 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
 
   _buffered_step = step
   _buffered_metrics = metrics
+  _last_buffer_time = next_buffer_time
+  _buffered_lr = lr
 
 
 def write_metrics_to_tensorboard(writer, metrics, step, config):
@@ -486,7 +502,6 @@ def train_loop(config, state=None):
   last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
 
   example_batch = None
-  last_step_completion = datetime.datetime.now()
 
   for step in np.arange(start_step, config.steps):
     if step == first_profiling_step:
@@ -500,9 +515,6 @@ def train_loop(config, state=None):
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics = p_train_step(state, example_batch, nextrng)
 
-    new_time = datetime.datetime.now()
-    record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
-    last_step_completion = new_time
 
     if checkpoint_manager is not None:
       if save_checkpoint(checkpoint_manager, step, state, config.dataset_type, data_iterator):
@@ -513,7 +525,8 @@ def train_loop(config, state=None):
       checkpoint_manager.wait_until_finished()
       sys.exit()
 
-    write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
+    write_metrics(writer, per_device_tflops, learning_rate_schedule(step),local_metrics_file, running_gcs_metrics, metrics,
+                  step, config)
 
     if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0:
       assert eval_data_iterator
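Taken together, the new write_metrics is a double-buffered publisher: it holds on to step N's still-lazy metrics and only publishes them once step N+1's metrics arrive, measuring step time between consecutive block points. A minimal standalone sketch of that pattern, not the MaxText code itself (`write_metrics_sketch` and the `publish` sink are hypothetical stand-ins for the TensorBoard/local/GCS writers):

```python
import datetime
import jax

_buffered_step = None
_buffered_metrics = None
_last_buffer_time = None


def write_metrics_sketch(metrics, step, publish):
  """Publish the *previous* step's metrics; `publish` is a hypothetical sink."""
  global _buffered_step, _buffered_metrics, _last_buffer_time

  if _buffered_metrics is not None:
    # The next step has already been dispatched by the caller, so this wait
    # overlaps host-side publishing with device execution.
    jax.block_until_ready(_buffered_metrics)

  # Taken right after the wait, this timestamp marks when the previous
  # step's results actually became available.
  next_buffer_time = datetime.datetime.now()

  if _buffered_metrics is not None:
    step_time = next_buffer_time - _last_buffer_time
    publish(_buffered_metrics, _buffered_step, step_time)

  # Hold this step's still-lazy metrics until the next call.
  _buffered_step = step
  _buffered_metrics = metrics
  _last_buffer_time = next_buffer_time
```

Called once per training step, this reports step time as the interval between consecutive "metrics became ready" moments, which is what the commit title means by better timestamps.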

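Why the deleted in-loop timing was off: JAX dispatches jitted computations asynchronously, so a `datetime.datetime.now()` taken straight after `p_train_step` records when the step was enqueued, not when it finished. A self-contained demonstration of the gap (toy workload and shapes, not from this repository):

```python
import datetime
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
  # Stand-in for p_train_step: enough work to make the async gap visible.
  for _ in range(10):
    x = jnp.tanh(x @ x)
  return x


x = jnp.ones((2048, 2048))
jax.block_until_ready(step(x))  # warm up so compilation is excluded

t0 = datetime.datetime.now()
y = step(x)
t_dispatch = datetime.datetime.now()  # returns almost immediately
jax.block_until_ready(y)
t_ready = datetime.datetime.now()     # after the device actually finishes

print("naive (dispatch only):", t_dispatch - t0)
print("blocked (real work):  ", t_ready - t0)
```

The commit moves the timestamp to right after `jax.block_until_ready` on the buffered metrics, so the measured interval reflects real device execution rather than dispatch latency.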