LCOV - code coverage report
Current view: top level - hybrid - store_load_hybrid.F90 (source / functions) Hit Total Coverage
Test: FLEUR test coverage Lines: 125 206 60.7 %
Date: 2024-05-02 04:21:52 Functions: 9 12 75.0 %

          Line data    Source code
       1             : module m_store_load_hybrid
       2             : #ifdef CPP_HDF
       3             :    USE hdf5
       4             : #endif
       5             :    use m_juDFT
       6             :    use m_types
       7             :    use m_mpi_bc_tool
       8             :    use m_juDFT
       9             :    use m_types_mpimat
      10             :    use m_distrib_vx
      11             : 
      12             :    character(len=*), parameter :: hybstore_fname = "hybrid.h5"
      13             :    public store_hybrid_data, load_hybrid_data
      14             : #ifdef CPP_HDF
      15             :    private open_file, open_datasetr, write_int_2d, close_dataset, close_file
      16             : #endif
      17             : contains
      18           6 :    subroutine load_hybrid_data(fi, fmpi, hybdat, mpdata)
      19             :       use m_constants
      20             :       use m_mixing_history
      21             :       implicit none
      22             :       type(t_fleurinput), intent(in)     :: fi
      23             :       type(t_mpi), intent(in)            :: fmpi
      24             :       type(t_hybdat), intent(inout)      :: hybdat
      25             :       type(t_mpdata), intent(inout)      :: mpdata
      26             : 
      27             :       logical :: l_exist
      28           6 :       integer, allocatable :: dims(:)
      29           6 :       character(len=:), allocatable :: dset_name
      30             :       integer                       :: ierr, nk, jsp
      31           6 :       real, allocatable             :: tmp(:, :)
      32           6 :       type(t_mat)                   :: vx_tmp
      33             : 
      34             : #ifdef CPP_HDF
      35             :       integer(HID_T)   :: dset_id
      36             :       INTEGER(HID_T)   :: file_id
      37             : 
      38             : 
      39           6 :       if(.not. hybdat%l_subvxc) then
      40           6 :          call timestart("load_hybrid_data")
      41           6 :          if (fmpi%is_root()) INQUIRE (file='hybrid.h5', exist=l_exist)
      42           6 :          call mpi_bc(l_exist, 0, fmpi%mpi_comm)
      43             : 
      44             : 
      45           6 :          if(l_exist .and. (.not. allocated(hybdat%v_x))) then
      46             : #ifdef CPP_MPI
      47           0 :             call MPI_Barrier(fmpi%mpi_comm, ierr)
      48             : #endif
      49           0 :             call mixing_history_reset(fmpi)
      50             : 
      51           0 :             IF (fmpi%n_size == 1) THEN
      52           0 :                ALLOCATE (t_mat::hybdat%v_x(fi%kpts%nkpt, fi%input%jspins))
      53             :             ELSE
      54           0 :                ALLOCATE (t_mpimat::hybdat%v_x(fi%kpts%nkpt, fi%input%jspins))
      55             :             END IF
      56             : 
      57           0 :             if (fmpi%is_root()) then
      58           0 :                call timestart("read part")
      59           0 :                file_id = open_file()
      60             : 
      61           0 :                dset_id = open_dataset(file_id, "nbands")
      62           0 :                if (.not. allocated(hybdat%nbands)) allocate (hybdat%nbands(fi%kpts%nkptf, fi%input%jspins))
      63           0 :                call read_int_2d(dset_id, hybdat%nbands)
      64           0 :                call close_dataset(dset_id)
      65             : 
      66           0 :                dset_id = open_dataset(file_id, "nobd")
      67           0 :                if (.not. allocated(hybdat%nobd)) allocate (hybdat%nobd(fi%kpts%nkptf, fi%input%jspins))
      68           0 :                call read_int_2d(dset_id, hybdat%nobd)
      69           0 :                call close_dataset(dset_id)
      70           0 :                call timestop("read part")
      71             :             end if
      72             : 
      73           0 :             do jsp = 1, fi%input%jspins
      74           0 :                do nk = 1, fi%kpts%nkpt
      75           0 :                   if(fmpi%is_root()) then
      76           0 :                      call timestart("read part")
      77           0 :                      if (fi%sym%invs) then
      78           0 :                         dset_name = "vx_nk="//int2str(nk)//"_jsp="//int2str(jsp)
      79           0 :                         dset_id = open_dataset(file_id, dset_name)
      80           0 :                         dims = get_dims(dset_id)
      81           0 :                         call vx_tmp%alloc(fi%sym%invs, dims(1), dims(2))
      82           0 :                         call read_dbl_2d(dset_id, vx_tmp%data_r)
      83           0 :                         call close_dataset(dset_id)
      84             :                      else
      85           0 :                         dset_name = "r_vx_nk="//int2str(nk)//"_jsp="//int2str(jsp)
      86           0 :                         dset_id = open_dataset(file_id, dset_name)
      87             : 
      88             :                         ! get dimensions and alloc space
      89           0 :                         dims = get_dims(dset_id)
      90           0 :                         call vx_tmp%alloc(fi%sym%invs, dims(1), dims(2))
      91           0 :                         allocate (tmp(dims(1), dims(2)), stat=ierr)
      92           0 :                         if (ierr /= 0) call juDFT_error("can't alloc tmp")
      93             : 
      94             :                         ! get real part
      95           0 :                         call read_dbl_2d(dset_id, tmp)
      96           0 :                         vx_tmp%data_c = tmp
      97           0 :                         call close_dataset(dset_id)
      98             : 
      99             :                         ! get complex part
     100           0 :                         dset_name = "c_vx_nk="//int2str(nk)//"_jsp="//int2str(jsp)
     101           0 :                         dset_id = open_dataset(file_id, dset_name)
     102           0 :                         call read_dbl_2d(dset_id, tmp)
     103           0 :                         vx_tmp%data_c = vx_tmp%data_c + ImagUnit*tmp
     104           0 :                         call close_dataset(dset_id)
     105           0 :                         deallocate (tmp)
     106             :                      end if
     107           0 :                      call timestop("read part")
     108             :                   endif
     109             : 
     110           0 :                   call mpi_bc(dims, 0, fmpi%mpi_comm)
     111           0 :                   call distrib_single_vx(fi, fmpi, jsp, nk, 0, vx_tmp, hybdat, dims=dims)
     112           0 :                   call vx_tmp%free()
     113             :                end do
     114             :             end do
     115             : 
     116           0 :             call timestart("bcast part")
     117           0 :             call mpi_bc(hybdat%nbands, 0, fmpi%mpi_comm)
     118           0 :             call mpi_bc(hybdat%nobd, 0, fmpi%mpi_comm)      
     119           0 :             call timestop("bcast part")
     120             : 
     121           0 :             call mpdata%set_num_radfun_per_l(fi%atoms)
     122           0 :             call hybdat%set_maxlmindx(fi%atoms, mpdata%num_radfun_per_l)
     123             : 
     124           0 :             hybdat%l_addhf = .True.
     125           0 :             hybdat%l_subvxc = .True.
     126             :          end if
     127           6 :          call timestop("load_hybrid_data")
     128             :       endif
     129             : #endif
     130           6 :    end subroutine load_hybrid_data
     131             : 
     132          12 :    subroutine store_hybrid_data(fi, fmpi, hybdat)
     133             :       implicit none
     134             :       type(t_fleurinput), intent(in)     :: fi
     135             :       type(t_mpi), intent(in)            :: fmpi
     136             :       type(t_hybdat), intent(in)         :: hybdat
     137             : 
     138             :       integer                       :: error, nk, jsp
     139          12 :       character(len=:), allocatable :: dset_name
     140          12 :       type(t_mat) :: vx_tmp
     141             : #ifdef CPP_HDF
     142             :       integer(HID_T)   :: dset_id
     143             :       INTEGER(HID_T)   :: file_id
     144             : 
     145          12 :       call timestart("store_hybrid_data")
     146             : 
     147          12 :       if(fmpi%irank == 0) then
     148           6 :          file_id = open_file()
     149             : 
     150          18 :          dset_id = open_dataset(file_id, "nbands", [fi%kpts%nkptf, fi%input%jspins], H5T_NATIVE_INTEGER)
     151           6 :          call write_int_2d(dset_id, hybdat%nbands)
     152           6 :          call close_dataset(dset_id)
     153             : 
     154          18 :          dset_id = open_dataset(file_id, "nobd", [fi%kpts%nkptf, fi%input%jspins], H5T_NATIVE_INTEGER)
     155           6 :          call write_int_2d(dset_id, hybdat%nobd)
     156           6 :          call close_dataset(dset_id)
     157             : 
     158          18 :          dset_id = open_dataset(file_id, "bkf", [3, fi%kpts%nkptf], H5T_NATIVE_DOUBLE)
     159           6 :          call write_dbl_2d(dset_id, fi%kpts%bkf)
     160           6 :          call close_dataset(dset_id)
     161             : 
     162          18 :          dset_id = open_dataset(file_id, "bkp", [fi%kpts%nkptf, 1], H5T_NATIVE_INTEGER)
     163           6 :          call write_int_1d(dset_id, fi%kpts%bkp)
     164          12 :          call close_dataset(dset_id)
     165             :       endif
     166             : 
     167             :       ! hdf5 only knows reals
     168          28 :       do jsp = 1, fi%input%jspins
     169          76 :          do nk = 1, fi%kpts%nkpt
     170          48 :             call collect_vx(fi, fmpi, hybdat, nk, jsp, vx_tmp)
     171          64 :             if(fmpi%irank == 0 ) then
     172          24 :                if (fi%sym%invs) then
     173          18 :                   dset_name = "vx_nk="//int2str(nk)//"_jsp="//int2str(jsp)
     174          54 :                   dset_id = open_dataset(file_id, dset_name, shape(vx_tmp%data_r), H5T_NATIVE_DOUBLE)
     175          18 :                   call write_dbl_2d(dset_id, vx_tmp%data_r)
     176          24 :                   call close_dataset(dset_id)
     177             :                else
     178           6 :                   dset_name = "r_vx_nk="//int2str(nk)//"_jsp="//int2str(jsp)
     179          18 :                   dset_id = open_dataset(file_id, dset_name, shape(vx_tmp%data_c), H5T_NATIVE_DOUBLE)
     180      196250 :                   call write_dbl_2d(dset_id, real(vx_tmp%data_c))
     181           6 :                   call close_dataset(dset_id)
     182             : 
     183           6 :                   dset_name = "c_vx_nk="//int2str(nk)//"_jsp="//int2str(jsp)
     184          18 :                   dset_id = open_dataset(file_id, dset_name, shape(vx_tmp%data_c), H5T_NATIVE_DOUBLE)
     185      196250 :                   call write_dbl_2d(dset_id, aimag(vx_tmp%data_c))
     186          24 :                   call close_dataset(dset_id)
     187             :                end if
     188          24 :                call vx_tmp%free()
     189             :             endif
     190             :          end do
     191             :       end do
     192             : 
     193          12 :       if(fmpi%irank == 0) call close_file(file_id)
     194          12 :       call timestop("store_hybrid_data")
     195             : #endif
     196          12 :    end subroutine store_hybrid_data
     197             : 
     198          48 :    subroutine collect_vx(fi, fmpi, hybdat, nk, jsp, vx_tmp)
     199             :       use m_glob_tofrom_loc
     200             :       implicit none 
     201             :       type(t_fleurinput), intent(in)     :: fi
     202             :       type(t_mpi), intent(in)            :: fmpi
     203             :       type(t_hybdat), intent(in)         :: hybdat
     204             :       integer, intent(in)                :: nk, jsp 
     205             :       type(t_mat), intent(inout)         :: vx_tmp
     206             :       
     207             :       integer, parameter :: recver = 0 ! HDF is node on global root
     208             :       integer :: sender, ierr, buff(2), i, pe_i, i_loc
     209             :       logical :: l_mpimat
     210             : 
     211             :       ! find out and bcast what kind of matrix we are using
     212          96 :       sender = merge(fmpi%irank, -1, any(fmpi%k_list == nk) .and. fmpi%n_rank == 0)
     213             : #ifdef CPP_MPI
     214          48 :       call MPI_Allreduce(MPI_IN_PLACE, sender, 1, MPI_INTEGER, MPI_MAX, fmpi%mpi_comm, ierr)
     215             : #endif
     216             : 
     217             :       select type(vx_origin => hybdat%v_x(nk,jsp)) 
     218             :       class is(t_mpimat) 
     219          48 :          if(sender == fmpi%irank) then
     220          72 :             buff = [vx_origin%global_size1, vx_origin%global_size2]
     221             :          endif
     222             :       class is(t_mat)
     223           0 :          if(sender == fmpi%irank) then
     224           0 :             buff = [vx_origin%matsize1, vx_origin%matsize2]
     225             :          endif
     226             :       end select
     227             : #ifdef CPP_MPI
     228          48 :       call MPI_Bcast(buff, 2, MPI_INTEGER, sender, fmpi%mpi_comm, ierr)
     229             : #endif
     230             : 
     231          48 :       if(fmpi%irank == recver) then
     232          24 :          call vx_tmp%init(fi%sym%invs, buff(1), buff(2))
     233             :       endif
     234             : 
     235        7256 :       do i = 1, buff(2)
     236        7208 :          call glob_to_loc(fmpi, i, pe_i, i_loc)
     237       14424 :          sender = merge(fmpi%irank, -1, pe_i == fmpi%n_rank .and. any(fmpi%k_list == nk))
     238             : #ifdef CPP_MPI
     239        7208 :          call MPI_Allreduce(MPI_IN_PLACE, sender, 1, MPI_INTEGER, MPI_MAX, fmpi%mpi_comm, ierr)
     240             : #endif
     241             : 
     242        7256 :          if(sender == recver .and. fmpi%irank == recver) then
     243        1806 :             if(vx_tmp%l_real) then
     244      231144 :                vx_tmp%data_r(:,i) = hybdat%v_x(nk,jsp)%data_r(:,i_loc)
     245             :             else
     246       98302 :                vx_tmp%data_c(:,i) = hybdat%v_x(nk,jsp)%data_c(:,i_loc)
     247             :             endif
     248             : #ifdef CPP_MPI
     249        5402 :          elseif(sender == fmpi%irank) then
     250        1798 :             if(fi%sym%invs) then
     251        1258 :                call MPI_Send(hybdat%v_x(nk,jsp)%data_r(:,i_loc), buff(1), MPI_DOUBLE_PRECISION, recver, 100+i, fmpi%mpi_comm,ierr)
     252             :             else
     253         540 :                call MPI_Send(hybdat%v_x(nk,jsp)%data_c(:,i_loc), buff(1), MPI_DOUBLE_COMPLEX, recver, 100+i, fmpi%mpi_comm,ierr)
     254             :             endif
     255        3604 :          elseif(fmpi%irank == recver) then
     256        1798 :             if(vx_tmp%l_real) then
     257        1258 :                call MPI_Recv(vx_tmp%data_r(:,i), buff(1), MPI_DOUBLE_PRECISION, sender, 100+i, fmpi%mpi_comm, MPI_STATUS_IGNORE, ierr)
     258             :             else
     259         540 :                call MPI_Recv(vx_tmp%data_c(:,i), buff(1), MPI_DOUBLE_COMPLEX, sender, 100+i, fmpi%mpi_comm, MPI_STATUS_IGNORE, ierr)
     260             :             endif
     261             : #endif
     262             :          endif
     263             :       enddo
     264          48 :    end subroutine collect_vx
     265             : 
     266             : #ifdef CPP_HDF
     267           6 :    function open_file() result(file_id)
     268             :       implicit none
     269             :       integer(HID_T) :: file_id
     270             : 
     271             :       logical :: file_exist
     272             :       integer :: error
     273             : 
     274           6 :       INQUIRE (file='hybrid.h5', exist=file_exist)
     275             : 
     276           6 :       if (file_exist) then
     277           3 :          CALL h5fopen_f(hybstore_fname, H5F_ACC_RDWR_F, file_id, error)
     278           3 :          if (error /= 0) call juDFT_error("cant't open hdf5 file")
     279             :       else
     280           3 :          CALL h5fcreate_f(hybstore_fname, H5F_ACC_TRUNC_F, file_id, error)
     281           3 :          if (error /= 0) call juDFT_error("cant't create hdf5 file")
     282             :       end if
     283           6 :    end function open_file
     284             : 
     285           6 :    subroutine close_file(file_id)
     286             :       implicit none
     287             :       integer(HID_T), intent(in) :: file_id
     288             : 
     289             :       integer :: ierr
     290             : 
     291           6 :       CALL h5fclose_f(file_id, ierr)
     292           6 :       if (ierr /= 0) call juDFT_error("can't close hdf5 file")
     293           6 :    end subroutine close_file
     294             : 
     295          54 :    function open_dataset(file_id, dsetname, in_dims, type_id) result(dset_id)
     296             :       implicit NONE
     297             :       integer(HID_T), intent(in)           :: file_id
     298             :       character(len=*), intent(in)         :: dsetname
     299             :       integer, optional, intent(in)        :: in_dims(:)
     300             :       integer(HID_T), intent(in), optional :: type_id
     301             : 
     302             :       INTEGER(HID_T) :: dset_id
     303             :       integer :: ierr
     304             :       logical :: dset_exists
     305             :       integer(HID_T)   :: dspace_id
     306          54 :       INTEGER(HSIZE_T), allocatable :: dims(:)
     307             : 
     308          54 :       call h5lexists_f(file_id, dsetname, dset_exists, ierr)
     309          54 :       if (ierr /= 0) call juDFT_error("Can't check if dataset exists")
     310             : 
     311          54 :       if (dset_exists) then
     312          27 :          call h5dopen_f(file_id, dsetname, dset_id, ierr)
     313             :       else
     314          27 :          if (present(in_dims)) then
     315          81 :             allocate (dims(size(in_dims)))
     316         108 :             dims = in_dims
     317             :          else
     318           0 :             call juDFT_error("dims needed for file creation")
     319             :          end if
     320             : 
     321          27 :          CALL h5screate_simple_f(2, dims, dspace_id, ierr)
     322          27 :          if (ierr /= 0) call juDFT_error("can't create dataspace")
     323             : 
     324          27 :          if (.not. present(type_id)) call juDFT_error("type_id needed for dataset creation")
     325          27 :          CALL h5dcreate_f(file_id, dsetname, type_id, dspace_id, dset_id, ierr)
     326          27 :          if (ierr /= 0) call juDFT_error("creating data set failed")
     327             : 
     328          27 :          CALL h5sclose_f(dspace_id, ierr)
     329          27 :          if (ierr /= 0) call juDFT_error("can't close dataspace")
     330             :       end if
     331          54 :    end function open_dataset
     332             : 
     333          54 :    subroutine close_dataset(dset_id)
     334             :       implicit none
     335             :       INTEGER(HID_T), intent(in)         :: dset_id
     336             : 
     337             :       integer :: ierr
     338             : 
     339          54 :       CALL h5dclose_f(dset_id, ierr)
     340          30 :    end subroutine close_dataset
     341             : 
     342           6 :    subroutine write_int_1d(dset_id, mtx)
     343             :       implicit none
     344             :       INTEGER(HID_T), intent(in)         :: dset_id
     345             :       integer, intent(in)                :: mtx(:)
     346             : 
     347             :       INTEGER(HSIZE_T), DIMENSION(2)     :: data_dims
     348             :       integer                            :: ierr
     349             : 
     350           6 :       data_dims(1) = size(mtx)
     351           6 :       data_dims(2) = 1
     352             : 
     353           6 :       CALL h5dwrite_f(dset_id, H5T_NATIVE_INTEGER, mtx, data_dims, ierr)
     354           6 :       if (ierr /= 0) call juDFT_error("can't write int 1d")
     355           6 :    end subroutine write_int_1d
     356             : 
     357          12 :    subroutine write_int_2d(dset_id, mtx)
     358             :       implicit none
     359             :       INTEGER(HID_T), intent(in)         :: dset_id
     360             :       integer, intent(in)                :: mtx(:, :)
     361             : 
     362             :       INTEGER(HSIZE_T), DIMENSION(2)     :: data_dims
     363             :       integer                            :: ierr
     364             : 
     365          36 :       data_dims = shape(mtx)
     366          12 :       CALL h5dwrite_f(dset_id, H5T_NATIVE_INTEGER, mtx, data_dims, ierr)
     367          12 :       if (ierr /= 0) call juDFT_error("can't write int 2d")
     368          12 :    end subroutine write_int_2d
     369             : 
     370           0 :    subroutine read_int_2d(dset_id, mtx)
     371             :       implicit none
     372             :       INTEGER(HID_T), intent(in)         :: dset_id
     373             :       integer, intent(inout)             :: mtx(:, :)
     374             : 
     375             :       INTEGER(HSIZE_T), DIMENSION(2)     :: data_dims
     376             :       integer                            :: ierr
     377             : 
     378           0 :       data_dims = shape(mtx)
     379           0 :       CALL h5dread_f(dset_id, H5T_NATIVE_INTEGER, mtx, data_dims, ierr)
     380           0 :       if (ierr /= 0) call juDFT_error("can't read int 2d")
     381           0 :    end subroutine read_int_2d
     382             : 
     383          36 :    subroutine write_dbl_2d(dset_id, mtx)
     384             :       implicit none
     385             :       INTEGER(HID_T), intent(in)         :: dset_id
     386             :       real, intent(in)                   :: mtx(:, :)
     387             : 
     388             :       INTEGER(HSIZE_T), DIMENSION(2)     :: data_dims
     389             :       integer                            :: ierr
     390             : 
     391         108 :       data_dims = shape(mtx)
     392          36 :       CALL h5dwrite_f(dset_id, H5T_NATIVE_DOUBLE, mtx, data_dims, ierr)
     393          36 :       if (ierr /= 0) call juDFT_error("can't write 2d real mtx")
     394          36 :    end subroutine write_dbl_2d
     395             : 
     396           0 :    subroutine read_dbl_2d(dset_id, mtx)
     397             :       implicit none
     398             :       INTEGER(HID_T), intent(in)     :: dset_id
     399             :       real, intent(inout)            :: mtx(:, :)
     400             : 
     401             :       INTEGER(HSIZE_T), DIMENSION(2)     :: data_dims
     402             :       integer                            :: ierr
     403             : 
     404           0 :       data_dims = shape(mtx)
     405           0 :       CALL h5dread_f(dset_id, H5T_NATIVE_DOUBLE, mtx, data_dims, ierr)
     406           0 :       if (ierr /= 0) call juDFT_error("can't read int 2d")
     407           0 :    end subroutine read_dbl_2d
     408             : 
     409           0 :    function get_dims(dset_id) result(dims)
     410             :       implicit none
     411             :       INTEGER(HID_T), intent(in)         :: dset_id
     412             :       integer, allocatable               :: dims(:)
     413             : 
     414           0 :       integer(HSIZE_T), allocatable      :: hdims(:), hmaxdims(:)
     415             :       integer(HID_T) :: dataspace_id
     416             :       integer        :: ierr, ndims
     417             : 
     418           0 :       call h5dget_space_f(dset_id, dataspace_id, ierr)
     419           0 :       if (ierr /= 0) call juDFT_error("can't get dataspace")
     420             : 
     421           0 :       call h5sget_simple_extent_ndims_f(dataspace_id, ndims, ierr)
     422           0 :       if (ierr /= 0) call juDFT_error("can't get ndims")
     423             : 
     424           0 :       allocate (hdims(ndims), hmaxdims(ndims))
     425             : 
     426           0 :       call h5sget_simple_extent_dims_f(dataspace_id, hdims, hmaxdims, ierr)
     427           0 :       if (ierr /= ndims) call juDFT_error("can't get dims")
     428             : 
     429           0 :       dims = hdims
     430             : 
     431           0 :       call h5sclose_f(dataspace_id, ierr)
     432           0 :       if (ierr /= 0) call juDFT_error("can't close dataspace")
     433           0 :    end function get_dims
     434             : 
     435             : #endif
     436          48 : end module m_store_load_hybrid

Generated by: LCOV version 1.14