Source code for pdfbl.sequential.sequential_cmi_runner

import json
import re
import threading
import time
import warnings
from pathlib import Path
from queue import Queue
from types import SimpleNamespace
from typing import Literal

from bg_mpl_stylesheets.styles import all_styles
from matplotlib import pyplot as plt
from prompt_toolkit import PromptSession
from prompt_toolkit.patch_stdout import patch_stdout

from pdfbl.sequential.pdfadapter import PDFAdapter

# Import-time side effect: switch matplotlib to the "bg-style" stylesheet
# shipped in bg_mpl_stylesheets so the figures created by the runner below
# are drawn with that style.
plt.style.use(all_styles["bg-style"])


class SequentialCMIRunner:
    """Run PDF refinements sequentially over an ordered series of profiles.

    Files found in the input data directory are ordered by the integer
    captured by ``filename_order_pattern`` and refined one after another.
    The refined variable values of one file are used as the initial values
    for the next file.

    Attributes
    ----------
    input_files_known : list of pathlib.Path
        All input files detected so far, in refinement order.
    input_files_completed : list of pathlib.Path
        The input files that have already been refined.
    input_files_running : list of pathlib.Path
        The input files that are waiting to be refined.
    adapter : PDFAdapter
        The adapter used to build and refine the recipe for each file.
    visualization_data : dict
        The plot bookkeeping created by ``_initialize_plots``: matplotlib
        lines together with the queues (and buffers) feeding them.
    """

    # Result entries that can be requested for plotting.
    _ALLOWED_RESULT_ENTRY_NAMES = (
        "residual",
        "contributions",
        "restraints",
        "chi2",
        "reduced_chi2",
    )

    # Keys of ``visualization_data`` that hold one growing series per name.
    _SERIES_CATEGORIES = ("variables", "results", "intermediate_results")

    def __init__(self):
        self.input_files_known = []
        self.input_files_completed = []
        self.input_files_running = []
        self.adapter = PDFAdapter()
        self.visualization_data = {}

    def _validate_inputs(self):
        """Check ``self.inputs`` and raise on the first invalid entry.

        Raises
        ------
        FileNotFoundError
            If a required directory or the structure file does not exist.
        NotADirectoryError
            If ``input_data_dir`` or ``output_result_dir`` is not a
            directory.
        ValueError
            If an input file does not match the filename order pattern, or
            a requested variable / result entry name is unknown.
        """
        for path_name in ("input_data_dir", "output_result_dir"):
            path = Path(self.inputs[path_name])
            if not path.exists():
                raise FileNotFoundError(
                    f"Path '{self.inputs[path_name]}' for "
                    f"'{path_name}' does not exist. Please check the "
                    "provided path."
                )
            if not path.is_dir():
                raise NotADirectoryError(
                    f"Path '{self.inputs[path_name]}' for "
                    f"'{path_name}' is not a directory. Please check the "
                    "provided path."
                )
        if not Path(self.inputs["structure_path"]).exists():
            raise FileNotFoundError(
                f"Structure file '{self.inputs['structure_path']}' does not "
                "exist. Please check the provided path."
            )

        profile_files = list(Path(self.inputs["input_data_dir"]).glob("*"))
        if profile_files:
            for tmp_file_path in profile_files:
                matches = re.findall(
                    self.inputs["filename_order_pattern"], tmp_file_path.name
                )
                if not matches:
                    raise ValueError(
                        f"Input file '{tmp_file_path}' does not match the "
                        "filename order pattern. Please check the pattern "
                        "or the input files."
                    )
            # Build a throw-away recipe from one profile only to learn which
            # variable names exist.
            tmp_adapter = PDFAdapter()
            tmp_adapter.initialize_profile(str(profile_files[-1]))
            tmp_adapter.initialize_structures([self.inputs["structure_path"]])
            tmp_adapter.initialize_contribution()
            tmp_adapter.initialize_recipe()
            allowed_variable_names = list(
                tmp_adapter.recipe._parameters.keys()
            )
            for var_name in self.inputs["refinable_variable_names"]:
                if var_name not in allowed_variable_names:
                    raise ValueError(
                        f"Refinable variable '{var_name}' not found in the "
                        "recipe. Please choose from the existing variables: "
                        f"{allowed_variable_names}"
                    )
            for var_name in self.inputs.get("plot_variable_names", []):
                if var_name not in allowed_variable_names:
                    raise ValueError(
                        f"Variable '{var_name}' is not found in the recipe. "
                        "Please choose from the existing variables: "
                        f"{allowed_variable_names}"
                    )
        else:
            # Without a profile no recipe can be built, so variable names
            # cannot be checked yet.
            warnings.warn(
                "No input profile files found in the input data directory. "
                "Skipping variable name validation."
            )

        allowed_result_entry_names = list(self._ALLOWED_RESULT_ENTRY_NAMES)
        for entry_name in self.inputs.get("plot_result_names", []):
            if entry_name not in allowed_result_entry_names:
                raise ValueError(
                    f"Result entry '{entry_name}' is not a valid entry to "
                    "plot. Please choose from the following entries: "
                    f"{allowed_result_entry_names}"
                )
        for entry_name in self.inputs.get(
            "plot_intermediate_result_names", []
        ):
            if entry_name not in allowed_result_entry_names:
                raise ValueError(
                    f"Intermediate result '{entry_name}' is not a valid "
                    "entry to plot. Please choose from the following "
                    "entries: "
                    f"{allowed_result_entry_names}"
                )

    def load_inputs(
        self,
        input_data_dir,
        structure_path,
        output_result_dir="results",
        filename_order_pattern=r"(\d+)K\.gr",
        whether_plot_y=False,
        whether_plot_ycalc=False,
        plot_variable_names=None,
        plot_result_names=None,
        plot_intermediate_result_names=None,
        refinable_variable_names=None,
        initial_variable_values=None,
        xmin=None,
        xmax=None,
        dx=None,
        qmin=None,
        qmax=None,
        show_plot=True,
    ):
        r"""Load and validate input configuration for sequential PDF
        refinement.

        This method initializes the sequential CMI runner with input data,
        structure information, and refinement parameters, and the plotting
        configuration.

        Parameters
        ----------
        input_data_dir : str
            The path to the directory containing input PDF profile files.
        structure_path : str
            The path to the structure file (e.g., CIF format) used for
            refinement.
        output_result_dir : str
            The path to the directory for storing refinement results.
            Default is "results".
        filename_order_pattern : str
            The regular expression pattern to extract ordering information
            from filenames. Default is r"(\d+)K\.gr" to extract temperature
            values from filenames.
        refinable_variable_names : list of str
            The list of variable names to refine. Must exist in the recipe.
            Default variable names are all possible variables that can be
            created from the input structure and profile.
        initial_variable_values : dict
            The dictionary mapping variable names to their initial values.
            Default is None.
        xmin : float
            The minimum x-value for the PDF profile. Default is the value
            parsed from the input file.
        xmax : float
            The maximum x-value for the PDF profile. Default is the value
            parsed from the input file.
        dx : float
            The step size for the PDF profile. Default is the value parsed
            from the input file.
        qmin : float
            The minimum q-value for the PDF profile. Default is the value
            parsed from the input file.
        qmax : float
            The maximum q-value for the PDF profile. Default is the value
            parsed from the input file.
        show_plot : bool
            Whether to display plots during refinement. Default is True.
        whether_plot_y : bool
            Whether to plot the experimental PDF data (y). Default is False.
        whether_plot_ycalc : bool
            Whether to plot the calculated PDF data (ycalc). Default is
            False.
        plot_variable_names : list of str
            The list of variable names to plot during refinement. Default is
            None.
        plot_result_names : list of str
            The list of fit result entries to plot. Allowed values:
            "residual", "contributions", "restraints", "chi2",
            "reduced_chi2". Default is None.
        plot_intermediate_result_names : list of str
            The list of intermediate result entries to plot during
            refinement. Allowed values: "residual", "contributions",
            "restraints", "chi2", "reduced_chi2". Default is None.

        Raises
        ------
        FileNotFoundError
            If the input data directory, output result directory, or
            structure file does not exist.
        NotADirectoryError
            If input_data_dir or output_result_dir is not a directory.
        ValueError
            If a refinable variable name is not found in the recipe, or if
            a plot result name is not valid.

        Examples
        --------
        >>> runner = SequentialCMIRunner()
        >>> runner.load_inputs(
        ...     input_data_dir="./data",
        ...     structure_path="./structure.cif",
        ...     output_result_dir="./results",
        ...     refinable_variable_names=["a", "all"],
        ...     plot_variable_names=["a"],
        ...     plot_result_names=["chi2"],
        ...     plot_intermediate_result_names=["residual"],
        ... )
        """
        self.inputs = {
            "input_data_dir": input_data_dir,
            "structure_path": structure_path,
            "output_result_dir": output_result_dir,
            "filename_order_pattern": filename_order_pattern,
            "xmin": xmin,
            "xmax": xmax,
            "dx": dx,
            "qmin": qmin,
            "qmax": qmax,
            "refinable_variable_names": refinable_variable_names or [],
            "initial_variable_values": initial_variable_values or {},
            "whether_plot_y": whether_plot_y,
            "whether_plot_ycalc": whether_plot_ycalc,
            "plot_variable_names": plot_variable_names or [],
            "plot_result_names": plot_result_names or [],
            "plot_intermediate_result_names": plot_intermediate_result_names
            or [],
        }
        self.show_plot = show_plot
        self._validate_inputs()
        self._initialize_plots()

    def _initialize_plots(self):
        """Create the figures and register their data channels.

        Fills ``self.visualization_data`` with one entry per profile curve
        ("y" / "ycalc") and one entry per series category ("variables",
        "results", "intermediate_results"). Intermediate result queues are
        also handed to the adapter so it can feed them during refinement.
        """
        whether_plot_y = self.inputs["whether_plot_y"]
        whether_plot_ycalc = self.inputs["whether_plot_ycalc"]
        plot_variable_names = self.inputs["plot_variable_names"]
        plot_result_names = self.inputs["plot_result_names"]
        plot_intermediate_result_names = self.inputs[
            "plot_intermediate_result_names"
        ]

        if whether_plot_y and whether_plot_ycalc:
            fig, _ = plt.subplots(2, 1)
            labels = ["ycalc", "y"]
        elif whether_plot_ycalc or whether_plot_y:
            fig, _ = plt.subplots()
            labels = ["ycalc"] if whether_plot_ycalc else ["y"]
        else:
            fig = None
            labels = []

        if fig:
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
            for i, (axis, label) in enumerate(zip(fig.axes, labels)):
                (line,) = axis.plot([], [], label=label, color=colors[i])
                self.visualization_data[label] = {
                    "line": line,
                    "xdata": Queue(),
                    "ydata": Queue(),
                }
            fig.legend()

        plot_tasks = [
            plot_variable_names,
            plot_result_names,
            plot_intermediate_result_names,
        ]
        for category, task in zip(self._SERIES_CATEGORIES, plot_tasks):
            if task is None:
                continue
            self.visualization_data[category] = {}
            for var_name in task:
                fig, ax = plt.subplots()
                (line,) = ax.plot([], [], label=var_name, marker="o")
                self.visualization_data[category][var_name] = {
                    "line": line,
                    "buffer": [],
                    "ydata": Queue(),
                }
                fig.suptitle(f"{category.capitalize()}: {var_name}")

        if plot_intermediate_result_names is not None:
            for var_name in plot_intermediate_result_names:
                self.adapter.monitor_intermediate_results(
                    var_name,
                    step=10,
                    queue=self.visualization_data["intermediate_results"][
                        var_name
                    ]["ydata"],
                )

    @staticmethod
    def _rescale_axes(line):
        """Recompute the data limits of the axes that own ``line``."""
        line.axes.relim()
        line.axes.autoscale_view()

    def _update_plot(self):
        """Move every queued value into its plot.

        All pending items are drained on each call. A single call after a
        batch run therefore shows every refined point, and profile curves
        show the most recently queued profile.
        """
        for key, plot_pack in self.visualization_data.items():
            if key in ("ycalc", "y"):
                latest = None
                # x and y are queued as a pair by the producer, so they are
                # always fetched as a pair here.
                while not plot_pack["xdata"].empty():
                    latest = (
                        plot_pack["xdata"].get(),
                        plot_pack["ydata"].get(),
                    )
                if latest is None:
                    continue
                line = plot_pack["line"]
                line.set_xdata(latest[0])
                line.set_ydata(latest[1])
                self._rescale_axes(line)
            elif key in self._SERIES_CATEGORIES:
                for data_pack in plot_pack.values():
                    pending = data_pack["ydata"]
                    buffer = data_pack["buffer"]
                    received = False
                    while not pending.empty():
                        buffer.append(pending.get())
                        received = True
                    if not received:
                        continue
                    line = data_pack["line"]
                    line.set_xdata(list(range(1, len(buffer) + 1)))
                    line.set_ydata(buffer)
                    self._rescale_axes(line)

    def _check_for_new_data(self):
        """Rescan the input directory and queue files not yet refined.

        Raises
        ------
        ValueError
            If a file name does not match the filename order pattern.
        RuntimeError
            If the previously known files are no longer a prefix of the
            sorted directory listing, i.e. a file appeared out of order.
        """
        input_data_dir = self.inputs["input_data_dir"]
        filename_order_pattern = self.inputs["filename_order_pattern"]

        def order_key(file):
            matches = re.findall(filename_order_pattern, file.name)
            if not matches:
                # Same message as the load-time validation, instead of a
                # bare IndexError from ``matches[0]``.
                raise ValueError(
                    f"Input file '{file}' does not match the "
                    "filename order pattern. Please check the pattern "
                    "or the input files."
                )
            return int(matches[0])

        sorted_file = sorted(Path(input_data_dir).glob("*"), key=order_key)
        if (
            self.input_files_known
            != sorted_file[: len(self.input_files_known)]
        ):
            raise RuntimeError(
                "Wrong order to run sequential toolset is detected. "
                "This is likely due to files appearing in the input directory "
                "in the wrong order. Please restart the sequential toolset."
            )
        if self.input_files_known == sorted_file:
            return
        self.input_files_known = sorted_file
        self.input_files_running = [
            f
            for f in self.input_files_known
            if f not in self.input_files_completed
        ]
        print(f"{[str(f) for f in self.input_files_running]} detected.")

    def set_start_input_file(
        self, input_filename, input_filename_to_result_filename
    ):
        """Set the starting input file for sequential refinement and
        continue the interrupted sequential refinement from that point.

        When the starting file is the first file of the series there is no
        earlier result to load, so the refinement starts from the configured
        initial variable values.

        Parameters
        ----------
        input_filename : str
            The name of the input file to start from. This file must be in
            the input data directory.
        input_filename_to_result_filename : function
            The function that takes an input filename and returns the
            corresponding result filename. This is used to locate the last
            result file for loading variable values.

        Raises
        ------
        ValueError
            If ``input_filename`` is not among the known input files.
        FileNotFoundError
            If the result file of the preceding input file does not exist.
        """
        self._check_for_new_data()
        input_file_path = Path(self.inputs["input_data_dir"]) / input_filename
        if input_file_path not in self.input_files_known:
            raise ValueError(
                f"Input file {input_filename} not found in known input files."
            )
        start_index = self.input_files_known.index(input_file_path)
        self.input_files_completed = self.input_files_known[:start_index]
        self.input_files_running = self.input_files_known[start_index:]

        if not self.input_files_completed:
            # Nothing precedes the first file; drop any stale values so the
            # next cycle falls back to the initial variable values.
            self.__dict__.pop("last_result_variables_values", None)
            print(
                f"Starting from input file: {self.input_files_running[0].name}"
            )
            return

        last_result_file = input_filename_to_result_filename(
            self.input_files_completed[-1].name
        )
        last_result_file = (
            Path(self.inputs["output_result_dir"]) / last_result_file
        )
        if not Path(last_result_file).exists():
            raise FileNotFoundError(
                f"Result file {last_result_file} not found. "
                "Cannot load last result variable values. "
                "Please check the provided function or use "
                "an earlier input file."
            )
        with open(last_result_file, "r") as result_file:
            last_result_variables_values = json.load(result_file)["variables"]
        self.last_result_variables_values = {
            name: pack["value"]
            for name, pack in last_result_variables_values.items()
        }
        print(f"Starting from input file: {self.input_files_running[0].name}")

    def _run_one_cycle(self, stop_event=SimpleNamespace(is_set=lambda: False)):
        """Refine every pending input file once, in order.

        Parameters
        ----------
        stop_event : object with an ``is_set()`` method
            Checked before each file; when set, the cycle stops early and
            the unprocessed files stay in ``input_files_running``. The
            default never requests a stop.
        """
        self._check_for_new_data()
        if not self.input_files_running:
            return None

        xmin = self.inputs["xmin"]
        xmax = self.inputs["xmax"]
        dx = self.inputs["dx"]
        qmin = self.inputs["qmin"]
        qmax = self.inputs["qmax"]
        structure_path = self.inputs["structure_path"]
        output_result_dir = self.inputs["output_result_dir"]
        initial_variable_values = self.inputs["initial_variable_values"]
        refinable_variable_names = self.inputs["refinable_variable_names"]

        interrupted = False
        for input_file in list(self.input_files_running):
            if stop_event.is_set():
                interrupted = True
                break
            print(f"Processing {input_file.name}...")
            self.adapter.initialize_profile(
                str(input_file),
                xmin=xmin,
                xmax=xmax,
                dx=dx,
                qmin=qmin,
                qmax=qmax,
            )
            self.adapter.initialize_structures([structure_path])
            self.adapter.initialize_contribution()
            self.adapter.initialize_recipe()
            if not hasattr(self, "last_result_variables_values"):
                self.last_result_variables_values = initial_variable_values
            self.adapter.set_initial_variable_values(
                self.last_result_variables_values
            )
            # NOTE(review): load_inputs stores [] (never None) when no names
            # are given, so this fallback is not reached through load_inputs;
            # confirm the intended default before relying on it.
            if refinable_variable_names is None:
                refinable_variable_names = list(initial_variable_values.keys())
            self.adapter.refine_variables(refinable_variable_names)
            results = self.adapter.save_results(
                filename=str(
                    Path(output_result_dir) / f"{input_file.stem}_result.json"
                ),
                mode="dict",
            )
            # The refined values seed the refinement of the next file.
            self.last_result_variables_values = {
                name: pack["value"]
                for name, pack in results["variables"].items()
            }
            self.input_files_completed.append(input_file)

            profile = None
            if "ycalc" in self.visualization_data:
                profile = self.adapter.recipe.pdfcontribution.profile
                self.visualization_data["ycalc"]["xdata"].put(profile.x)
                self.visualization_data["ycalc"]["ydata"].put(profile.ycalc)
            if "y" in self.visualization_data:
                profile = self.adapter.recipe.pdfcontribution.profile
                self.visualization_data["y"]["xdata"].put(profile.x)
                self.visualization_data["y"]["ydata"].put(profile.y)
            for var_name in self.visualization_data.get("variables", {}):
                new_value = self.adapter.recipe._parameters[var_name].value
                self.visualization_data["variables"][var_name]["ydata"].put(
                    new_value
                )
            result_plots = self.visualization_data.get("results", {})
            if result_plots:
                # One snapshot serves every plotted entry.
                fitresults_dict = self.adapter.save_results(mode="dict")
                for entry_name in result_plots:
                    entry_value = fitresults_dict.get(entry_name, None)
                    result_plots[entry_name]["ydata"].put(entry_value)

        if interrupted:
            # Keep what was not processed instead of reporting completion.
            self.input_files_running = [
                f
                for f in self.input_files_running
                if f not in self.input_files_completed
            ]
            return None
        print("Completed!")
        self.input_files_running = []

    def _collect_plot_buffers(self):
        """Return ``{category: {name: buffer}}`` for all buffered series.

        Entries without a ``"buffer"`` (the "y" / "ycalc" profile curves,
        whose values are a line and two queues) are skipped.
        """
        collected = {}
        for category_name, data_pack in self.visualization_data.items():
            for var_name, var_pack in data_pack.items():
                if isinstance(var_pack, dict) and "buffer" in var_pack:
                    collected.setdefault(category_name, {})[var_name] = (
                        var_pack["buffer"]
                    )
        return collected

    def run(self, mode: Literal["batch", "stream"]):
        """Run the sequential refinement process in either batch or
        streaming mode.

        Parameters
        ----------
        mode : str
            The mode to run the sequential refinement. Must be either
            "batch" or "stream". In "batch" mode, the toolset will run
            through all available input files once and then stop. In
            "stream" mode, the runner will continuously monitor the input
            data directory for new files and process them as they appear,
            until the user decides to stop the process.

        Raises
        ------
        ValueError
            If ``mode`` is neither "batch" nor "stream".
        """
        if mode == "batch":
            self._run_one_cycle()
            self._update_plot()
        elif mode == "stream":
            stop_event = threading.Event()
            session = PromptSession()
            if (self.visualization_data is not None) and self.show_plot:
                plt.ion()
                plt.pause(0.01)

            def stream_loop():
                while not stop_event.is_set():
                    self._run_one_cycle(stop_event)
                    stop_event.wait(1)  # Check for new data every 1s

            def input_loop():
                with patch_stdout():
                    print("=== COMMANDS ===")
                    print("Type STOP to exit")
                    print("================")
                    while not stop_event.is_set():
                        cmd = session.prompt("> ")
                        if cmd.strip() == "STOP":
                            stop_event.set()
                            print(
                                "Stopping the streaming sequential toolset..."
                            )
                        else:
                            print(
                                "Unrecognized input. "
                                "Please type 'STOP' to end."
                            )
                # Persist every plotted series once the user has stopped.
                visualization_data = self._collect_plot_buffers()
                with open("visualization_data.json", "w") as f:
                    json.dump(visualization_data, f, indent=2)

            input_thread = threading.Thread(target=input_loop)
            input_thread.start()
            fit_thread = threading.Thread(target=stream_loop)
            fit_thread.start()
            # Plot updates stay on the calling thread while the two worker
            # threads handle refinement and user input.
            while not stop_event.is_set():
                self._update_plot()
                plt.pause(0.01)
                time.sleep(1)
            fit_thread.join()
            input_thread.join()
        else:
            raise ValueError(f"Unknown mode: {mode}")